summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/back
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/back')
-rw-r--r--third_party/rust/naga/src/back/dot/mod.rs669
-rw-r--r--third_party/rust/naga/src/back/glsl/features.rs525
-rw-r--r--third_party/rust/naga/src/back/glsl/keywords.rs204
-rw-r--r--third_party/rust/naga/src/back/glsl/mod.rs3840
-rw-r--r--third_party/rust/naga/src/back/hlsl/conv.rs227
-rw-r--r--third_party/rust/naga/src/back/hlsl/help.rs1195
-rw-r--r--third_party/rust/naga/src/back/hlsl/keywords.rs166
-rw-r--r--third_party/rust/naga/src/back/hlsl/mod.rs280
-rw-r--r--third_party/rust/naga/src/back/hlsl/storage.rs433
-rw-r--r--third_party/rust/naga/src/back/hlsl/writer.rs2980
-rw-r--r--third_party/rust/naga/src/back/mod.rs209
-rw-r--r--third_party/rust/naga/src/back/msl/keywords.rs217
-rw-r--r--third_party/rust/naga/src/back/msl/mod.rs497
-rw-r--r--third_party/rust/naga/src/back/msl/sampler.rs175
-rw-r--r--third_party/rust/naga/src/back/msl/writer.rs3985
-rw-r--r--third_party/rust/naga/src/back/spv/block.rs2121
-rw-r--r--third_party/rust/naga/src/back/spv/helpers.rs108
-rw-r--r--third_party/rust/naga/src/back/spv/image.rs1179
-rw-r--r--third_party/rust/naga/src/back/spv/index.rs417
-rw-r--r--third_party/rust/naga/src/back/spv/instructions.rs996
-rw-r--r--third_party/rust/naga/src/back/spv/layout.rs210
-rw-r--r--third_party/rust/naga/src/back/spv/mod.rs696
-rw-r--r--third_party/rust/naga/src/back/spv/recyclable.rs60
-rw-r--r--third_party/rust/naga/src/back/spv/selection.rs257
-rw-r--r--third_party/rust/naga/src/back/spv/writer.rs1695
-rw-r--r--third_party/rust/naga/src/back/wgsl/mod.rs52
-rw-r--r--third_party/rust/naga/src/back/wgsl/writer.rs2061
27 files changed, 25454 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs
new file mode 100644
index 0000000000..ce88f02e8e
--- /dev/null
+++ b/third_party/rust/naga/src/back/dot/mod.rs
@@ -0,0 +1,669 @@
+/*!
+Backend for [DOT][dot] (Graphviz).
+
+This backend writes a graph in the DOT language, for the ease
+of IR inspection and debugging.
+
+[dot]: https://graphviz.org/doc/info/lang.html
+*/
+
+use crate::{
+ arena::Handle,
+ valid::{FunctionInfo, ModuleInfo},
+};
+
+use std::{
+ borrow::Cow,
+ fmt::{Error as FmtError, Write as _},
+};
+
+/// Configuration options for the dot backend
+#[derive(Default)]
+pub struct Options {
+ /// Only emit function bodies
+ pub cfg_only: bool,
+}
+
+/// Identifier used to address a graph node
+type NodeId = usize;
+
+/// Stores the target nodes for control flow statements
+#[derive(Default, Clone, Copy)]
+struct Targets {
+ /// The node, if some, where continue operations will land
+ continue_target: Option<usize>,
+ /// The node, if some, where break operations will land
+ break_target: Option<usize>,
+}
+
+/// Stores information about the graph of statements
+#[derive(Default)]
+struct StatementGraph {
+ /// List of node names
+ nodes: Vec<&'static str>,
+ /// List of edges of the control flow, the items are defined as
+ /// (from, to, label)
+ flow: Vec<(NodeId, NodeId, &'static str)>,
+ /// List of implicit edges of the control flow, used for jump
+ /// operations such as continue or break, the items are defined as
+ /// (from, to, label, color_id)
+ jumps: Vec<(NodeId, NodeId, &'static str, usize)>,
+ /// List of dependency relationships between a statement node and
+ /// expressions
+ dependencies: Vec<(NodeId, Handle<crate::Expression>, &'static str)>,
+ /// List of expression emitted by statement node
+ emits: Vec<(NodeId, Handle<crate::Expression>)>,
+ /// List of function call by statement node
+ calls: Vec<(NodeId, Handle<crate::Function>)>,
+}
+
+impl StatementGraph {
+ /// Adds a new block to the statement graph, returning the first and last node, respectively
+ fn add(&mut self, block: &[crate::Statement], targets: Targets) -> (NodeId, NodeId) {
+ use crate::Statement as S;
+
+ // The first node of the block isn't a statement but a virtual node
+ let root = self.nodes.len();
+ self.nodes.push(if root == 0 { "Root" } else { "Node" });
+ // Track the last placed node, this will be returned to the caller and
+ // will also be used to generate the control flow edges
+ let mut last_node = root;
+ for statement in block {
+ // Reserve a new node for the current statement and link it to the
+ // node of the previous statement
+ let id = self.nodes.len();
+ self.flow.push((last_node, id, ""));
+ self.nodes.push(""); // reserve space
+
+ // Track the node identifier for the merge node, the merge node is
+ // the last node of a statement, normally this is the node itself,
+ // but for control flow statements such as `if`s and `switch`s this
+ // is a virtual node where all branches merge back.
+ let mut merge_id = id;
+
+ self.nodes[id] = match *statement {
+ S::Emit(ref range) => {
+ for handle in range.clone() {
+ self.emits.push((id, handle));
+ }
+ "Emit"
+ }
+ S::Kill => "Kill", //TODO: link to the beginning
+ S::Break => {
+ // Try to link to the break target, otherwise produce
+ // a broken connection
+ if let Some(target) = targets.break_target {
+ self.jumps.push((id, target, "Break", 5))
+ } else {
+ self.jumps.push((id, root, "Broken", 7))
+ }
+ "Break"
+ }
+ S::Continue => {
+ // Try to link to the continue target, otherwise produce
+ // a broken connection
+ if let Some(target) = targets.continue_target {
+ self.jumps.push((id, target, "Continue", 5))
+ } else {
+ self.jumps.push((id, root, "Broken", 7))
+ }
+ "Continue"
+ }
+ S::Barrier(_flags) => "Barrier",
+ S::Block(ref b) => {
+ let (other, last) = self.add(b, targets);
+ self.flow.push((id, other, ""));
+ // All following nodes should connect to the end of the block
+ // statement so change the merge id to it.
+ merge_id = last;
+ "Block"
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ self.dependencies.push((id, condition, "condition"));
+ let (accept_id, accept_last) = self.add(accept, targets);
+ self.flow.push((id, accept_id, "accept"));
+ let (reject_id, reject_last) = self.add(reject, targets);
+ self.flow.push((id, reject_id, "reject"));
+
+ // Create a merge node, link the branches to it and set it
+ // as the merge node to make the next statement node link to it
+ merge_id = self.nodes.len();
+ self.nodes.push("Merge");
+ self.flow.push((accept_last, merge_id, ""));
+ self.flow.push((reject_last, merge_id, ""));
+
+ "If"
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ self.dependencies.push((id, selector, "selector"));
+
+ // Create a merge node and set it as the merge node to make
+ // the next statement node link to it
+ merge_id = self.nodes.len();
+ self.nodes.push("Merge");
+
+ // Create a new targets structure and set the break target
+ // to the merge node
+ let mut targets = targets;
+ targets.break_target = Some(merge_id);
+
+ for case in cases {
+ let (case_id, case_last) = self.add(&case.body, targets);
+ let label = match case.value {
+ crate::SwitchValue::Integer(_) => "case",
+ crate::SwitchValue::Default => "default",
+ };
+ self.flow.push((id, case_id, label));
+ // Link the last node of the branch to the merge node
+ self.flow.push((case_last, merge_id, ""));
+ }
+ "Switch"
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ // Create a new targets structure and set the break target
+ // to the merge node, this must happen before generating the
+ // continuing block since it can break.
+ let mut targets = targets;
+ targets.break_target = Some(id);
+
+ let (continuing_id, continuing_last) = self.add(continuing, targets);
+
+ // Set the the continue target to the beginning
+ // of the newly generated continuing block
+ targets.continue_target = Some(continuing_id);
+
+ let (body_id, body_last) = self.add(body, targets);
+
+ self.flow.push((id, body_id, "body"));
+
+ // Link the last node of the body to the continuing block
+ self.flow.push((body_last, continuing_id, "continuing"));
+ // Link the last node of the continuing block back to the
+ // beginning of the loop body
+ self.flow.push((continuing_last, body_id, "continuing"));
+
+ if let Some(expr) = break_if {
+ self.dependencies.push((continuing_id, expr, "break if"));
+ }
+
+ "Loop"
+ }
+ S::Return { value } => {
+ if let Some(expr) = value {
+ self.dependencies.push((id, expr, "value"));
+ }
+ "Return"
+ }
+ S::Store { pointer, value } => {
+ self.dependencies.push((id, value, "value"));
+ self.emits.push((id, pointer));
+ "Store"
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ self.dependencies.push((id, image, "image"));
+ self.dependencies.push((id, coordinate, "coordinate"));
+ if let Some(expr) = array_index {
+ self.dependencies.push((id, expr, "array_index"));
+ }
+ self.dependencies.push((id, value, "value"));
+ "ImageStore"
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ for &arg in arguments {
+ self.dependencies.push((id, arg, "arg"));
+ }
+ if let Some(expr) = result {
+ self.emits.push((id, expr));
+ }
+ self.calls.push((id, function));
+ "Call"
+ }
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ self.emits.push((id, result));
+ self.dependencies.push((id, pointer, "pointer"));
+ self.dependencies.push((id, value, "value"));
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ self.dependencies.push((id, cmp, "cmp"));
+ }
+ "Atomic"
+ }
+ };
+ // Set the last node to the merge node
+ last_node = merge_id;
+ }
+ (root, last_node)
+ }
+}
+
+#[allow(clippy::manual_unwrap_or)]
+fn name(option: &Option<String>) -> &str {
+ match *option {
+ Some(ref name) => name,
+ None => "",
+ }
+}
+
+/// set39 color scheme from <https://graphviz.org/doc/info/colors.html>
+const COLORS: &[&str] = &[
+ "white", // pattern starts at 1
+ "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69", "#fccde5",
+ "#d9d9d9",
+];
+
+fn write_fun(
+ output: &mut String,
+ prefix: String,
+ fun: &crate::Function,
+ info: Option<&FunctionInfo>,
+ options: &Options,
+) -> Result<(), FmtError> {
+ writeln!(output, "\t\tnode [ style=filled ]")?;
+
+ if !options.cfg_only {
+ for (handle, var) in fun.local_variables.iter() {
+ writeln!(
+ output,
+ "\t\t{}_l{} [ shape=hexagon label=\"{:?} '{}'\" ]",
+ prefix,
+ handle.index(),
+ handle,
+ name(&var.name),
+ )?;
+ }
+
+ write_function_expressions(output, &prefix, fun, info)?;
+ }
+
+ let mut sg = StatementGraph::default();
+ sg.add(&fun.body, Targets::default());
+ for (index, label) in sg.nodes.into_iter().enumerate() {
+ writeln!(
+ output,
+ "\t\t{}_s{} [ shape=square label=\"{}\" ]",
+ prefix, index, label,
+ )?;
+ }
+ for (from, to, label) in sg.flow {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> {}_s{} [ arrowhead=tee label=\"{}\" ]",
+ prefix, from, prefix, to, label,
+ )?;
+ }
+ for (from, to, label, color_id) in sg.jumps {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> {}_s{} [ arrowhead=tee style=dashed color=\"{}\" label=\"{}\" ]",
+ prefix, from, prefix, to, COLORS[color_id], label,
+ )?;
+ }
+
+ if !options.cfg_only {
+ for (to, expr, label) in sg.dependencies {
+ writeln!(
+ output,
+ "\t\t{}_e{} -> {}_s{} [ label=\"{}\" ]",
+ prefix,
+ expr.index(),
+ prefix,
+ to,
+ label,
+ )?;
+ }
+ for (from, to) in sg.emits {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> {}_e{} [ style=dotted ]",
+ prefix,
+ from,
+ prefix,
+ to.index(),
+ )?;
+ }
+ }
+
+ for (from, function) in sg.calls {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> f{}_s0",
+ prefix,
+ from,
+ function.index(),
+ )?;
+ }
+
+ Ok(())
+}
+
+fn write_function_expressions(
+ output: &mut String,
+ prefix: &str,
+ fun: &crate::Function,
+ info: Option<&FunctionInfo>,
+) -> Result<(), FmtError> {
+ enum Payload<'a> {
+ Arguments(&'a [Handle<crate::Expression>]),
+ Local(Handle<crate::LocalVariable>),
+ Global(Handle<crate::GlobalVariable>),
+ }
+
+ let mut edges = crate::FastHashMap::<&str, _>::default();
+ let mut payload = None;
+ for (handle, expression) in fun.expressions.iter() {
+ use crate::Expression as E;
+ let (label, color_id) = match *expression {
+ E::Access { base, index } => {
+ edges.insert("base", base);
+ edges.insert("index", index);
+ ("Access".into(), 1)
+ }
+ E::AccessIndex { base, index } => {
+ edges.insert("base", base);
+ (format!("AccessIndex[{}]", index).into(), 1)
+ }
+ E::Constant(_) => ("Constant".into(), 2),
+ E::Splat { size, value } => {
+ edges.insert("value", value);
+ (format!("Splat{:?}", size).into(), 3)
+ }
+ E::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ edges.insert("vector", vector);
+ (format!("Swizzle{:?}", &pattern[..size as usize]).into(), 3)
+ }
+ E::Compose { ref components, .. } => {
+ payload = Some(Payload::Arguments(components));
+ ("Compose".into(), 3)
+ }
+ E::FunctionArgument(index) => (format!("Argument[{}]", index).into(), 1),
+ E::GlobalVariable(h) => {
+ payload = Some(Payload::Global(h));
+ ("Global".into(), 2)
+ }
+ E::LocalVariable(h) => {
+ payload = Some(Payload::Local(h));
+ ("Local".into(), 1)
+ }
+ E::Load { pointer } => {
+ edges.insert("pointer", pointer);
+ ("Load".into(), 4)
+ }
+ E::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset: _,
+ level,
+ depth_ref,
+ } => {
+ edges.insert("image", image);
+ edges.insert("sampler", sampler);
+ edges.insert("coordinate", coordinate);
+ if let Some(expr) = array_index {
+ edges.insert("array_index", expr);
+ }
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {}
+ crate::SampleLevel::Exact(expr) => {
+ edges.insert("level", expr);
+ }
+ crate::SampleLevel::Bias(expr) => {
+ edges.insert("bias", expr);
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ edges.insert("grad_x", x);
+ edges.insert("grad_y", y);
+ }
+ }
+ if let Some(expr) = depth_ref {
+ edges.insert("depth_ref", expr);
+ }
+ let string = match gather {
+ Some(component) => Cow::Owned(format!("ImageGather{:?}", component)),
+ _ => Cow::Borrowed("ImageSample"),
+ };
+ (string, 5)
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ edges.insert("image", image);
+ edges.insert("coordinate", coordinate);
+ if let Some(expr) = array_index {
+ edges.insert("array_index", expr);
+ }
+ if let Some(sample) = sample {
+ edges.insert("sample", sample);
+ }
+ if let Some(level) = level {
+ edges.insert("level", level);
+ }
+ ("ImageLoad".into(), 5)
+ }
+ E::ImageQuery { image, query } => {
+ edges.insert("image", image);
+ let args = match query {
+ crate::ImageQuery::Size { level } => {
+ if let Some(expr) = level {
+ edges.insert("level", expr);
+ }
+ Cow::from("ImageSize")
+ }
+ _ => Cow::Owned(format!("{:?}", query)),
+ };
+ (args, 7)
+ }
+ E::Unary { op, expr } => {
+ edges.insert("expr", expr);
+ (format!("{:?}", op).into(), 6)
+ }
+ E::Binary { op, left, right } => {
+ edges.insert("left", left);
+ edges.insert("right", right);
+ (format!("{:?}", op).into(), 6)
+ }
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ edges.insert("condition", condition);
+ edges.insert("accept", accept);
+ edges.insert("reject", reject);
+ ("Select".into(), 3)
+ }
+ E::Derivative { axis, expr } => {
+ edges.insert("", expr);
+ (format!("d{:?}", axis).into(), 8)
+ }
+ E::Relational { fun, argument } => {
+ edges.insert("arg", argument);
+ (format!("{:?}", fun).into(), 6)
+ }
+ E::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ edges.insert("arg", arg);
+ if let Some(expr) = arg1 {
+ edges.insert("arg1", expr);
+ }
+ if let Some(expr) = arg2 {
+ edges.insert("arg2", expr);
+ }
+ if let Some(expr) = arg3 {
+ edges.insert("arg3", expr);
+ }
+ (format!("{:?}", fun).into(), 7)
+ }
+ E::As {
+ kind,
+ expr,
+ convert,
+ } => {
+ edges.insert("", expr);
+ let string = match convert {
+ Some(width) => format!("Convert<{:?},{}>", kind, width),
+ None => format!("Bitcast<{:?}>", kind),
+ };
+ (string.into(), 3)
+ }
+ E::CallResult(_function) => ("CallResult".into(), 4),
+ E::AtomicResult { .. } => ("AtomicResult".into(), 4),
+ E::ArrayLength(expr) => {
+ edges.insert("", expr);
+ ("ArrayLength".into(), 7)
+ }
+ };
+
+ // give uniform expressions an outline
+ let color_attr = match info {
+ Some(info) if info[handle].uniformity.non_uniform_result.is_none() => "fillcolor",
+ _ => "color",
+ };
+ writeln!(
+ output,
+ "\t\t{}_e{} [ {}=\"{}\" label=\"{:?} {}\" ]",
+ prefix,
+ handle.index(),
+ color_attr,
+ COLORS[color_id],
+ handle,
+ label,
+ )?;
+
+ for (key, edge) in edges.drain() {
+ writeln!(
+ output,
+ "\t\t{}_e{} -> {}_e{} [ label=\"{}\" ]",
+ prefix,
+ edge.index(),
+ prefix,
+ handle.index(),
+ key,
+ )?;
+ }
+ match payload.take() {
+ Some(Payload::Arguments(list)) => {
+ write!(output, "\t\t{{")?;
+ for &comp in list {
+ write!(output, " {}_e{}", prefix, comp.index())?;
+ }
+ writeln!(output, " }} -> {}_e{}", prefix, handle.index())?;
+ }
+ Some(Payload::Local(h)) => {
+ writeln!(
+ output,
+ "\t\t{}_l{} -> {}_e{}",
+ prefix,
+ h.index(),
+ prefix,
+ handle.index(),
+ )?;
+ }
+ Some(Payload::Global(h)) => {
+ writeln!(
+ output,
+ "\t\tg{} -> {}_e{} [fillcolor=gray]",
+ h.index(),
+ prefix,
+ handle.index(),
+ )?;
+ }
+ None => {}
+ }
+ }
+
+ Ok(())
+}
+
+/// Write shader module to a [`String`].
+pub fn write(
+ module: &crate::Module,
+ mod_info: Option<&ModuleInfo>,
+ options: Options,
+) -> Result<String, FmtError> {
+ use std::fmt::Write as _;
+
+ let mut output = String::new();
+ output += "digraph Module {\n";
+
+ if !options.cfg_only {
+ writeln!(output, "\tsubgraph cluster_globals {{")?;
+ writeln!(output, "\t\tlabel=\"Globals\"")?;
+ for (handle, var) in module.global_variables.iter() {
+ writeln!(
+ output,
+ "\t\tg{} [ shape=hexagon label=\"{:?} {:?}/'{}'\" ]",
+ handle.index(),
+ handle,
+ var.space,
+ name(&var.name),
+ )?;
+ }
+ writeln!(output, "\t}}")?;
+ }
+
+ for (handle, fun) in module.functions.iter() {
+ let prefix = format!("f{}", handle.index());
+ writeln!(output, "\tsubgraph cluster_{} {{", prefix)?;
+ writeln!(
+ output,
+ "\t\tlabel=\"Function{:?}/'{}'\"",
+ handle,
+ name(&fun.name)
+ )?;
+ let info = mod_info.map(|a| &a[handle]);
+ write_fun(&mut output, prefix, fun, info, &options)?;
+ writeln!(output, "\t}}")?;
+ }
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let prefix = format!("ep{}", ep_index);
+ writeln!(output, "\tsubgraph cluster_{} {{", prefix)?;
+ writeln!(output, "\t\tlabel=\"{:?}/'{}'\"", ep.stage, ep.name)?;
+ let info = mod_info.map(|a| a.get_entry_point(ep_index));
+ write_fun(&mut output, prefix, &ep.function, info, &options)?;
+ writeln!(output, "\t}}")?;
+ }
+
+ output += "}\n";
+ Ok(output)
+}
diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs
new file mode 100644
index 0000000000..b898b1d2b3
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/features.rs
@@ -0,0 +1,525 @@
+use super::{BackendResult, Error, Version, Writer};
+use crate::{
+ AddressSpace, Binding, Bytes, Expression, Handle, ImageClass, ImageDimension, Interpolation,
+ MathFunction, Sampling, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner,
+};
+use std::fmt::Write;
+
+bitflags::bitflags! {
+ /// Structure used to encode additions to GLSL that aren't supported by all versions.
+ pub struct Features: u32 {
+ /// Buffer address space support.
+ const BUFFER_STORAGE = 1;
+ const ARRAY_OF_ARRAYS = 1 << 1;
+ /// 8 byte floats.
+ const DOUBLE_TYPE = 1 << 2;
+ /// More image formats.
+ const FULL_IMAGE_FORMATS = 1 << 3;
+ const MULTISAMPLED_TEXTURES = 1 << 4;
+ const MULTISAMPLED_TEXTURE_ARRAYS = 1 << 5;
+ const CUBE_TEXTURES_ARRAY = 1 << 6;
+ const COMPUTE_SHADER = 1 << 7;
+ /// Image load and early depth tests.
+ const IMAGE_LOAD_STORE = 1 << 8;
+ const CONSERVATIVE_DEPTH = 1 << 9;
+ /// Interpolation and auxiliary qualifiers.
+ ///
+ /// Perspective, Flat, and Centroid are available in all GLSL versions we support.
+ const NOPERSPECTIVE_QUALIFIER = 1 << 11;
+ const SAMPLE_QUALIFIER = 1 << 12;
+ const CLIP_DISTANCE = 1 << 13;
+ const CULL_DISTANCE = 1 << 14;
+ /// Sample ID.
+ const SAMPLE_VARIABLES = 1 << 15;
+ /// Arrays with a dynamic length.
+ const DYNAMIC_ARRAY_SIZE = 1 << 16;
+ const MULTI_VIEW = 1 << 17;
+ /// Fused multiply-add.
+ const FMA = 1 << 18;
+ /// Texture samples query
+ const TEXTURE_SAMPLES = 1 << 19;
+ /// Texture levels query
+ const TEXTURE_LEVELS = 1 << 20;
+ /// Image size query
+ const IMAGE_SIZE = 1 << 21;
+ }
+}
+
+/// Helper structure used to store the required [`Features`] needed to output a
+/// [`Module`](crate::Module)
+///
+/// Provides helper methods to check for availability and writing required extensions
+pub struct FeaturesManager(Features);
+
+impl FeaturesManager {
+ /// Creates a new [`FeaturesManager`] instance
+ pub const fn new() -> Self {
+ Self(Features::empty())
+ }
+
+ /// Adds to the list of required [`Features`]
+ pub fn request(&mut self, features: Features) {
+ self.0 |= features
+ }
+
+ /// Checks that all required [`Features`] are available for the specified
+ /// [`Version`](super::Version) otherwise returns an
+ /// [`Error::MissingFeatures`](super::Error::MissingFeatures)
+ pub fn check_availability(&self, version: Version) -> BackendResult {
+ // Will store all the features that are unavailable
+ let mut missing = Features::empty();
+
+ // Helper macro to check for feature availability
+ macro_rules! check_feature {
+ // Used when only core glsl supports the feature
+ ($feature:ident, $core:literal) => {
+ if self.0.contains(Features::$feature)
+ && (version < Version::Desktop($core) || version.is_es())
+ {
+ missing |= Features::$feature;
+ }
+ };
+ // Used when both core and es support the feature
+ ($feature:ident, $core:literal, $es:literal) => {
+ if self.0.contains(Features::$feature)
+ && (version < Version::Desktop($core) || version < Version::new_gles($es))
+ {
+ missing |= Features::$feature;
+ }
+ };
+ }
+
+ check_feature!(COMPUTE_SHADER, 420, 310);
+ check_feature!(BUFFER_STORAGE, 400, 310);
+ check_feature!(DOUBLE_TYPE, 150);
+ check_feature!(CUBE_TEXTURES_ARRAY, 130, 310);
+ check_feature!(MULTISAMPLED_TEXTURES, 150, 300);
+ check_feature!(MULTISAMPLED_TEXTURE_ARRAYS, 150, 310);
+ check_feature!(ARRAY_OF_ARRAYS, 120, 310);
+ check_feature!(IMAGE_LOAD_STORE, 130, 310);
+ check_feature!(CONSERVATIVE_DEPTH, 130, 300);
+ check_feature!(CONSERVATIVE_DEPTH, 130, 300);
+ check_feature!(NOPERSPECTIVE_QUALIFIER, 130);
+ check_feature!(SAMPLE_QUALIFIER, 400, 320);
+ // gl_ClipDistance is supported by core versions > 1.3 and aren't supported by an es versions without extensions
+ check_feature!(CLIP_DISTANCE, 130, 300);
+ check_feature!(CULL_DISTANCE, 450, 300);
+ check_feature!(SAMPLE_VARIABLES, 400, 300);
+ check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
+ match version {
+ Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
+ _ => check_feature!(MULTI_VIEW, 140, 310),
+ };
+ // Only available on glsl core, this means that opengl es can't query the number
+ // of samples nor levels in a image and neither do bound checks on the sample nor
+ // the level argument of texelFecth
+ check_feature!(TEXTURE_SAMPLES, 150);
+ check_feature!(TEXTURE_LEVELS, 130);
+ check_feature!(IMAGE_SIZE, 430, 310);
+
+ // Return an error if there are missing features
+ if missing.is_empty() {
+ Ok(())
+ } else {
+ Err(Error::MissingFeatures(missing))
+ }
+ }
+
+ /// Helper method used to write all needed extensions
+ ///
+ /// # Notes
+ /// This won't check for feature availability so it might output extensions that aren't even
+ /// supported.[`check_availability`](Self::check_availability) will check feature availability
+ pub fn write(&self, version: Version, mut out: impl Write) -> BackendResult {
+ if self.0.contains(Features::COMPUTE_SHADER) && !version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_compute_shader.txt
+ writeln!(out, "#extension GL_ARB_compute_shader : require")?;
+ }
+
+ if self.0.contains(Features::BUFFER_STORAGE) && !version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_storage_buffer_object.txt
+ writeln!(
+ out,
+ "#extension GL_ARB_shader_storage_buffer_object : require"
+ )?;
+ }
+
+ if self.0.contains(Features::DOUBLE_TYPE) && version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gpu_shader_fp64.txt
+ writeln!(out, "#extension GL_ARB_gpu_shader_fp64 : require")?;
+ }
+
+ if self.0.contains(Features::CUBE_TEXTURES_ARRAY) {
+ if version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_EXT_texture_cube_map_array : require")?;
+ } else if version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_ARB_texture_cube_map_array : require")?;
+ }
+ }
+
+ if self.0.contains(Features::MULTISAMPLED_TEXTURE_ARRAYS) && version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_texture_storage_multisample_2d_array.txt
+ writeln!(
+ out,
+ "#extension GL_OES_texture_storage_multisample_2d_array : require"
+ )?;
+ }
+
+ if self.0.contains(Features::ARRAY_OF_ARRAYS) && version < Version::Desktop(430) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_arrays_of_arrays.txt
+ writeln!(out, "#extension ARB_arrays_of_arrays : require")?;
+ }
+
+ if self.0.contains(Features::IMAGE_LOAD_STORE) {
+ if self.0.contains(Features::FULL_IMAGE_FORMATS) && version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/NV/NV_image_formats.txt
+ writeln!(out, "#extension GL_NV_image_formats : require")?;
+ }
+
+ if version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_image_load_store.txt
+ writeln!(out, "#extension GL_ARB_shader_image_load_store : require")?;
+ }
+ }
+
+ if self.0.contains(Features::CONSERVATIVE_DEPTH) {
+ if version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_conservative_depth.txt
+ writeln!(out, "#extension GL_EXT_conservative_depth : require")?;
+ }
+
+ if version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
+ writeln!(out, "#extension GL_ARB_conservative_depth : require")?;
+ }
+ }
+
+ if (self.0.contains(Features::CLIP_DISTANCE) || self.0.contains(Features::CULL_DISTANCE))
+ && version.is_es()
+ {
+ // TODO: handle gl_ClipDistance and gl_CullDistance usage in better way
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_clip_cull_distance.txt
+ // writeln!(out, "#extension GL_EXT_clip_cull_distance : require")?;
+ }
+
+ if self.0.contains(Features::SAMPLE_VARIABLES) && version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_sample_variables.txt
+ writeln!(out, "#extension GL_OES_sample_variables : require")?;
+ }
+
+ if self.0.contains(Features::SAMPLE_VARIABLES) && version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_sample_variables.txt
+ writeln!(out, "#extension GL_OES_sample_variables : require")?;
+ }
+
+ if self.0.contains(Features::MULTI_VIEW) {
+ if let Version::Embedded { is_webgl: true, .. } = version {
+ // https://www.khronos.org/registry/OpenGL/extensions/OVR/OVR_multiview2.txt
+ writeln!(out, "#extension GL_OVR_multiview2 : require")?;
+ } else {
+ // https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_multiview.txt
+ writeln!(out, "#extension GL_EXT_multiview : require")?;
+ }
+ }
+
+ if self.0.contains(Features::FMA) && version >= Version::new_gles(310) {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_gpu_shader5.txt
+ writeln!(out, "#extension GL_EXT_gpu_shader5 : require")?;
+ }
+
+ if self.0.contains(Features::TEXTURE_SAMPLES) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_texture_image_samples.txt
+ writeln!(
+ out,
+ "#extension GL_ARB_shader_texture_image_samples : require"
+ )?;
+ }
+
+ if self.0.contains(Features::TEXTURE_LEVELS) && version < Version::Desktop(430) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_query_levels.txt
+ writeln!(out, "#extension GL_ARB_texture_query_levels : require")?;
+ }
+
+ Ok(())
+ }
+}
+
+impl<'a, W> Writer<'a, W> {
+ /// Helper method that searches the module for all the needed [`Features`]
+ ///
+ /// # Errors
+ /// If the version doesn't support any of the needed [`Features`] a
+ /// [`Error::MissingFeatures`](super::Error::MissingFeatures) will be returned
+ pub(super) fn collect_required_features(&mut self) -> BackendResult {
+ let ep_info = self.info.get_entry_point(self.entry_point_idx as usize);
+
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+ // If IMAGE_LOAD_STORE is supported for this version of GLSL
+ if self.options.version.supports_early_depth_test() {
+ self.features.request(Features::IMAGE_LOAD_STORE);
+ }
+
+ if depth_test.conservative.is_some() {
+ self.features.request(Features::CONSERVATIVE_DEPTH);
+ }
+ }
+
+ for arg in self.entry_point.function.arguments.iter() {
+ self.varying_required_features(arg.binding.as_ref(), arg.ty);
+ }
+ if let Some(ref result) = self.entry_point.function.result {
+ self.varying_required_features(result.binding.as_ref(), result.ty);
+ }
+
+ if let ShaderStage::Compute = self.entry_point.stage {
+ self.features.request(Features::COMPUTE_SHADER)
+ }
+
+ if self.multiview.is_some() {
+ self.features.request(Features::MULTI_VIEW);
+ }
+
+ for (ty_handle, ty) in self.module.types.iter() {
+ match ty.inner {
+ TypeInner::Scalar { kind, width } => self.scalar_required_features(kind, width),
+ TypeInner::Vector { kind, width, .. } => self.scalar_required_features(kind, width),
+ TypeInner::Matrix { width, .. } => {
+ self.scalar_required_features(ScalarKind::Float, width)
+ }
+ TypeInner::Array { base, size, .. } => {
+ if let TypeInner::Array { .. } = self.module.types[base].inner {
+ self.features.request(Features::ARRAY_OF_ARRAYS)
+ }
+
+ // If the array is dynamically sized
+ if size == crate::ArraySize::Dynamic {
+ let mut is_used = false;
+
+ // Check if this type is used in a global that is needed by the current entrypoint
+ for (global_handle, global) in self.module.global_variables.iter() {
+ // Skip unused globals
+ if ep_info[global_handle].is_empty() {
+ continue;
+ }
+
+ // If this array is the type of a global, then this array is used
+ if global.ty == ty_handle {
+ is_used = true;
+ break;
+ }
+
+ // If the type of this global is a struct
+ if let crate::TypeInner::Struct { ref members, .. } =
+ self.module.types[global.ty].inner
+ {
+ // Check the last element of the struct to see if it's type uses
+ // this array
+ if let Some(last) = members.last() {
+ if last.ty == ty_handle {
+ is_used = true;
+ break;
+ }
+ }
+ }
+ }
+
+ // If this dynamically size array is used, we need dynamic array size support
+ if is_used {
+ self.features.request(Features::DYNAMIC_ARRAY_SIZE);
+ }
+ }
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ if arrayed && dim == ImageDimension::Cube {
+ self.features.request(Features::CUBE_TEXTURES_ARRAY)
+ }
+
+ match class {
+ ImageClass::Sampled { multi: true, .. }
+ | ImageClass::Depth { multi: true } => {
+ self.features.request(Features::MULTISAMPLED_TEXTURES);
+ if arrayed {
+ self.features.request(Features::MULTISAMPLED_TEXTURE_ARRAYS);
+ }
+ }
+ ImageClass::Storage { format, .. } => match format {
+ StorageFormat::R8Unorm
+ | StorageFormat::R8Snorm
+ | StorageFormat::R8Uint
+ | StorageFormat::R8Sint
+ | StorageFormat::R16Uint
+ | StorageFormat::R16Sint
+ | StorageFormat::R16Float
+ | StorageFormat::Rg8Unorm
+ | StorageFormat::Rg8Snorm
+ | StorageFormat::Rg8Uint
+ | StorageFormat::Rg8Sint
+ | StorageFormat::Rg16Uint
+ | StorageFormat::Rg16Sint
+ | StorageFormat::Rg16Float
+ | StorageFormat::Rgb10a2Unorm
+ | StorageFormat::Rg11b10Float
+ | StorageFormat::Rg32Uint
+ | StorageFormat::Rg32Sint
+ | StorageFormat::Rg32Float => {
+ self.features.request(Features::FULL_IMAGE_FORMATS)
+ }
+ _ => {}
+ },
+ ImageClass::Sampled { multi: false, .. }
+ | ImageClass::Depth { multi: false } => {}
+ }
+ }
+ _ => {}
+ }
+ }
+
+ let mut push_constant_used = false;
+
+ for (handle, global) in self.module.global_variables.iter() {
+ if ep_info[handle].is_empty() {
+ continue;
+ }
+ match global.space {
+ AddressSpace::WorkGroup => self.features.request(Features::COMPUTE_SHADER),
+ AddressSpace::Storage { .. } => self.features.request(Features::BUFFER_STORAGE),
+ AddressSpace::PushConstant => {
+ if push_constant_used {
+ return Err(Error::MultiplePushConstants);
+ }
+ push_constant_used = true;
+ }
+ _ => {}
+ }
+ }
+
+ // We will need to pass some of the members to a closure, so we need
+ // to separate them otherwise the borrow checker will complain, this
+ // shouldn't be needed in rust 2021
+ let &mut Self {
+ module,
+ info,
+ ref mut features,
+ entry_point,
+ entry_point_idx,
+ ref policies,
+ ..
+ } = self;
+
+ // Loop trough all expressions in both functions and the entry point
+ // to check for needed features
+ for (expressions, info) in module
+ .functions
+ .iter()
+ .map(|(h, f)| (&f.expressions, &info[h]))
+ .chain(std::iter::once((
+ &entry_point.function.expressions,
+ info.get_entry_point(entry_point_idx as usize),
+ )))
+ {
+ for (_, expr) in expressions.iter() {
+ match *expr {
+ // Check for fused multiply add use
+ Expression::Math { fun, .. } if fun == MathFunction::Fma => {
+ features.request(Features::FMA)
+ }
+ // Check for queries that neeed aditonal features
+ Expression::ImageQuery {
+ image,
+ query,
+ ..
+ } => match query {
+ // Storage images use `imageSize` which is only available
+ // in glsl > 420
+ //
+ // layers queries are also implemented as size queries
+ crate::ImageQuery::Size { .. } | crate::ImageQuery::NumLayers => {
+ if let TypeInner::Image {
+ class: crate::ImageClass::Storage { .. }, ..
+ } = *info[image].ty.inner_with(&module.types) {
+ features.request(Features::IMAGE_SIZE)
+ }
+ },
+ crate::ImageQuery::NumLevels => features.request(Features::TEXTURE_LEVELS),
+ crate::ImageQuery::NumSamples => features.request(Features::TEXTURE_SAMPLES),
+ }
+ ,
+ // Check for image loads that needs bound checking on the sample
+ // or level argument since this requires a feature
+ Expression::ImageLoad {
+ sample, level, ..
+ } => {
+ if policies.image != crate::proc::BoundsCheckPolicy::Unchecked {
+ if sample.is_some() {
+ features.request(Features::TEXTURE_SAMPLES)
+ }
+
+ if level.is_some() {
+ features.request(Features::TEXTURE_LEVELS)
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+
+ self.features.check_availability(self.options.version)
+ }
+
+ /// Helper method that checks the [`Features`] needed by a scalar
+ fn scalar_required_features(&mut self, kind: ScalarKind, width: Bytes) {
+ if kind == ScalarKind::Float && width == 8 {
+ self.features.request(Features::DOUBLE_TYPE);
+ }
+ }
+
+ fn varying_required_features(&mut self, binding: Option<&Binding>, ty: Handle<Type>) {
+ match self.module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for member in members {
+ self.varying_required_features(member.binding.as_ref(), member.ty);
+ }
+ }
+ _ => {
+ if let Some(binding) = binding {
+ match *binding {
+ Binding::BuiltIn(built_in) => match built_in {
+ crate::BuiltIn::ClipDistance => {
+ self.features.request(Features::CLIP_DISTANCE)
+ }
+ crate::BuiltIn::CullDistance => {
+ self.features.request(Features::CULL_DISTANCE)
+ }
+ crate::BuiltIn::SampleIndex => {
+ self.features.request(Features::SAMPLE_VARIABLES)
+ }
+ crate::BuiltIn::ViewIndex => {
+ self.features.request(Features::MULTI_VIEW)
+ }
+ _ => {}
+ },
+ Binding::Location {
+ location: _,
+ interpolation,
+ sampling,
+ } => {
+ if interpolation == Some(Interpolation::Linear) {
+ self.features.request(Features::NOPERSPECTIVE_QUALIFIER);
+ }
+ if sampling == Some(Sampling::Sample) {
+ self.features.request(Features::SAMPLE_QUALIFIER);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/glsl/keywords.rs b/third_party/rust/naga/src/back/glsl/keywords.rs
new file mode 100644
index 0000000000..5a2836c189
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/keywords.rs
@@ -0,0 +1,204 @@
+pub const RESERVED_KEYWORDS: &[&str] = &[
+ "attribute",
+ "const",
+ "uniform",
+ "varying",
+ "buffer",
+ "shared",
+ "coherent",
+ "volatile",
+ "restrict",
+ "readonly",
+ "writeonly",
+ "atomic_uint",
+ "layout",
+ "centroid",
+ "flat",
+ "smooth",
+ "noperspective",
+ "patch",
+ "sample",
+ "break",
+ "continue",
+ "do",
+ "for",
+ "while",
+ "switch",
+ "case",
+ "default",
+ "if",
+ "else",
+ "subroutine",
+ "in",
+ "out",
+ "inout",
+ "float",
+ "double",
+ "int",
+ "void",
+ "bool",
+ "true",
+ "false",
+ "invariant",
+ "precise",
+ "discard",
+ "return",
+ "mat2",
+ "mat3",
+ "mat4",
+ "dmat2",
+ "dmat3",
+ "dmat4",
+ "mat2x2",
+ "mat2x3",
+ "mat2x4",
+ "dmat2x2",
+ "dmat2x3",
+ "dmat2x4",
+ "mat3x2",
+ "mat3x3",
+ "mat3x4",
+ "dmat3x2",
+ "dmat3x3",
+ "dmat3x4",
+ "mat4x2",
+ "mat4x3",
+ "mat4x4",
+ "dmat4x2",
+ "dmat4x3",
+ "dmat4x4",
+ "vec2",
+ "vec3",
+ "vec4",
+ "ivec2",
+ "ivec3",
+ "ivec4",
+ "bvec2",
+ "bvec3",
+ "bvec4",
+ "dvec2",
+ "dvec3",
+ "dvec4",
+ "uint",
+ "uvec2",
+ "uvec3",
+ "uvec4",
+ "lowp",
+ "mediump",
+ "highp",
+ "precision",
+ "sampler1D",
+ "sampler2D",
+ "sampler3D",
+ "samplerCube",
+ "sampler1DShadow",
+ "sampler2DShadow",
+ "samplerCubeShadow",
+ "sampler1DArray",
+ "sampler2DArray",
+ "sampler1DArrayShadow",
+ "sampler2DArrayShadow",
+ "isampler1D",
+ "isampler2D",
+ "isampler3D",
+ "isamplerCube",
+ "isampler1DArray",
+ "isampler2DArray",
+ "usampler1D",
+ "usampler2D",
+ "usampler3D",
+ "usamplerCube",
+ "usampler1DArray",
+ "usampler2DArray",
+ "sampler2DRect",
+ "sampler2DRectShadow",
+ "isampler2D",
+ "Rect",
+ "usampler2DRect",
+ "samplerBuffer",
+ "isamplerBuffer",
+ "usamplerBuffer",
+ "sampler2DMS",
+ "isampler2DMS",
+ "usampler2DMS",
+ "sampler2DMSArray",
+ "isampler2DMSArray",
+ "usampler2DMSArray",
+ "samplerCubeArray",
+ "samplerCubeArrayShadow",
+ "isamplerCubeArray",
+ "usamplerCubeArray",
+ "image1D",
+ "iimage1D",
+ "uimage1D",
+ "image2D",
+ "iimage2D",
+ "uimage2D",
+ "image3D",
+ "iimage3D",
+ "uimage3D",
+ "image2DRect",
+ "iimage2DRect",
+ "uimage2DRect",
+ "imageCube",
+ "iimageCube",
+ "uimageCube",
+ "imageBuffer",
+ "iimageBuffer",
+ "uimageBuffer",
+ "image1DArray",
+ "iimage1DArray",
+ "uimage1DArray",
+ "image2DArray",
+ "iimage2DArray",
+ "uimage2DArray",
+ "imageCubeArray",
+ "iimageCubeArray",
+ "uimageCubeArray",
+ "image2DMS",
+ "iimage2DMS",
+ "uimage2DMS",
+ "image2DMSArray",
+ "iimage2DMSArray",
+ "uimage2DMSArraystruct",
+ "common",
+ "partition",
+ "active",
+ "asm",
+ "class",
+ "union",
+ "enum",
+ "typedef",
+ "template",
+ "this",
+ "resource",
+ "goto",
+ "inline",
+ "noinline",
+ "public",
+ "static",
+ "extern",
+ "external",
+ "interface",
+ "long",
+ "short",
+ "half",
+ "fixed",
+ "unsigned",
+ "superp",
+ "input",
+ "output",
+ "hvec2",
+ "hvec3",
+ "hvec4",
+ "fvec2",
+ "fvec3",
+ "fvec4",
+ "sampler3DRect",
+ "filter",
+ "sizeof",
+ "cast",
+ "namespace",
+ "using",
+ "main",
+];
diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs
new file mode 100644
index 0000000000..a3f2a53836
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/mod.rs
@@ -0,0 +1,3840 @@
+/*!
+Backend for [GLSL][glsl] (OpenGL Shading Language).
+
+The main structure is [`Writer`], it maintains internal state that is used
+to output a [`Module`](crate::Module) into glsl
+
+# Supported versions
+### Core
+- 330
+- 400
+- 410
+- 420
+- 430
+- 450
+
+### ES
+- 300
+- 310
+
+[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php
+*/
+
+// GLSL is mostly a superset of C but it also removes some parts of it this is a list of relevant
+// aspects for this backend.
+//
+// The most notable change is the introduction of the version preprocessor directive that must
+// always be the first line of a glsl file and is written as
+// `#version number profile`
+// `number` is the version itself (i.e. 300) and `profile` is the
+// shader profile we only support "core" and "es", the former is used in desktop applications and
+// the later is used in embedded contexts, mobile devices and browsers. Each one as it's own
+// versions (at the time of writing this the latest version for "core" is 460 and for "es" is 320)
+//
+// Other important preprocessor addition is the extension directive which is written as
+// `#extension name: behaviour`
+// Extensions provide increased features in a plugin fashion but they aren't required to be
+// supported hence why they are called extensions, that's why `behaviour` is used it specifies
+// whether the extension is strictly required or if it should only be enabled if needed. In our case
+// when we use extensions we set behaviour to `require` always.
+//
+// The only thing that glsl removes that makes a difference are pointers.
+//
+// Additions that are relevant for the backend are the discard keyword, the introduction of
+// vector, matrices, samplers, image types and functions that provide common shader operations
+
+pub use features::Features;
+
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, ShaderStage, TypeInner,
+};
+use features::FeaturesManager;
+use std::{
+ cmp::Ordering,
+ fmt,
+ fmt::{Error as FmtError, Write},
+};
+use thiserror::Error;
+
+/// Contains the features related code and the features querying method
+mod features;
+/// Contains a constant with a slice of all the reserved keywords RESERVED_KEYWORDS
+mod keywords;
+
+/// List of supported `core` GLSL versions.
+pub const SUPPORTED_CORE_VERSIONS: &[u16] = &[330, 400, 410, 420, 430, 440, 450];
+/// List of supported `es` GLSL versions.
+pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320];
+
+/// The suffix of the variable that will hold the calculated clamped level
+/// of detail for bounds checking in `ImageLoad`
+const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod";
+
+/// Mapping between resources and bindings.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>;
+
+impl crate::AtomicFunction {
+ const fn to_glsl(self) -> &'static str {
+ match self {
+ Self::Add | Self::Subtract => "Add",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { compare: Some(_) } => "", //TODO
+ }
+ }
+}
+
+impl crate::AddressSpace {
+ const fn is_buffer(&self) -> bool {
+ match *self {
+ crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => true,
+ _ => false,
+ }
+ }
+
+ /// Whether a variable with this address space can be initialized
+ const fn initializable(&self) -> bool {
+ match *self {
+ crate::AddressSpace::Function | crate::AddressSpace::Private => true,
+ crate::AddressSpace::WorkGroup
+ | crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::Handle
+ | crate::AddressSpace::PushConstant => false,
+ }
+ }
+}
+
+/// A GLSL version.
+#[derive(Debug, Copy, Clone, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum Version {
+ /// `core` GLSL.
+ Desktop(u16),
+ /// `es` GLSL.
+ Embedded { version: u16, is_webgl: bool },
+}
+
+impl Version {
+ /// Create a new gles version
+ pub const fn new_gles(version: u16) -> Self {
+ Self::Embedded {
+ version,
+ is_webgl: false,
+ }
+ }
+
+ /// Returns true if self is `Version::Embedded` (i.e. is a es version)
+ const fn is_es(&self) -> bool {
+ match *self {
+ Version::Desktop(_) => false,
+ Version::Embedded { .. } => true,
+ }
+ }
+
+ /// Returns true if targetting WebGL
+ const fn is_webgl(&self) -> bool {
+ match *self {
+ Version::Desktop(_) => false,
+ Version::Embedded { is_webgl, .. } => is_webgl,
+ }
+ }
+
+ /// Checks the list of currently supported versions and returns true if it contains the
+ /// specified version
+ ///
+ /// # Notes
+ /// As an invalid version number will never be added to the supported version list
+ /// so this also checks for version validity
+ fn is_supported(&self) -> bool {
+ match *self {
+ Version::Desktop(v) => SUPPORTED_CORE_VERSIONS.contains(&v),
+ Version::Embedded { version: v, .. } => SUPPORTED_ES_VERSIONS.contains(&v),
+ }
+ }
+
+ /// Checks if the version supports all of the explicit layouts:
+ /// - `location=` qualifiers for bindings
+ /// - `binding=` qualifiers for resources
+ ///
+ /// Note: `location=` for vertex inputs and fragment outputs is supported
+ /// unconditionally for GLES 300.
+ fn supports_explicit_locations(&self) -> bool {
+ *self >= Version::Desktop(410) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_early_depth_test(&self) -> bool {
+ *self >= Version::Desktop(130) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_std430_layout(&self) -> bool {
+ *self >= Version::Desktop(430) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_fma_function(&self) -> bool {
+ *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
+ }
+}
+
+impl PartialOrd for Version {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ match (*self, *other) {
+ (Version::Desktop(x), Version::Desktop(y)) => Some(x.cmp(&y)),
+ (Version::Embedded { version: x, .. }, Version::Embedded { version: y, .. }) => {
+ Some(x.cmp(&y))
+ }
+ _ => None,
+ }
+ }
+}
+
+impl fmt::Display for Version {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match *self {
+ Version::Desktop(v) => write!(f, "{} core", v),
+ Version::Embedded { version: v, .. } => write!(f, "{} es", v),
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Configuration flags for the [`Writer`].
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct WriterFlags: u32 {
+ /// Flip output Y and extend Z from (0, 1) to (-1, 1).
+ const ADJUST_COORDINATE_SPACE = 0x1;
+ /// Supports GL_EXT_texture_shadow_lod on the host, which provides
+ /// additional functions on shadows and arrays of shadows.
+ const TEXTURE_SHADOW_LOD = 0x2;
+ }
+}
+
+/// Configuration used in the [`Writer`].
+#[derive(Debug, Clone)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// The GLSL version to be used.
+ pub version: Version,
+ /// Configuration flags for the [`Writer`].
+ pub writer_flags: WriterFlags,
+ /// Map of resources association to binding locations.
+ pub binding_map: BindingMap,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ version: Version::new_gles(310),
+ writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE,
+ binding_map: BindingMap::default(),
+ }
+ }
+}
+
+/// A subset of options meant to be changed per pipeline.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// The stage of the entry point.
+ pub shader_stage: ShaderStage,
+ /// The name of the entry point.
+ ///
+ /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
+ pub entry_point: String,
+ /// How many views to render to, if doing multiview rendering.
+ pub multiview: Option<std::num::NonZeroU32>,
+}
+
+/// Reflection info for texture mappings and uniforms.
+pub struct ReflectionInfo {
+ /// Mapping between texture names and variables/samplers.
+ pub texture_mapping: crate::FastHashMap<String, TextureMapping>,
+ /// Mapping between uniform variables and names.
+ pub uniforms: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
+}
+
+/// Mapping between a texture and its sampler, if it exists.
+///
+/// GLSL pre-Vulkan has no concept of separate textures and samplers. Instead, everything is a
+/// `gsamplerN` where `g` is the scalar type and `N` is the dimension. But naga uses separate textures
+/// and samplers in the IR, so the backend produces a [`FastHashMap`](crate::FastHashMap) with the texture name
+/// as a key and a [`TextureMapping`] as a value. This way, the user knows where to bind.
+///
+/// [`Storage`](crate::ImageClass::Storage) images produce `gimageN` and don't have an associated sampler,
+/// so the [`sampler`](Self::sampler) field will be [`None`].
+#[derive(Debug, Clone)]
+pub struct TextureMapping {
+ /// Handle to the image global variable.
+ pub texture: Handle<crate::GlobalVariable>,
+ /// Handle to the associated sampler global variable, if it exists.
+ pub sampler: Option<Handle<crate::GlobalVariable>>,
+}
+
+/// Helper structure that generates a number
+#[derive(Default)]
+struct IdGenerator(u32);
+
+impl IdGenerator {
+ /// Generates a number that's guaranteed to be unique for this `IdGenerator`
+ fn generate(&mut self) -> u32 {
+ // It's just an increasing number but it does the job
+ let ret = self.0;
+ self.0 += 1;
+ ret
+ }
+}
+
+/// Helper wrapper used to get a name for a varying
+///
+/// Varying have different naming schemes depending on their binding:
+/// - Varyings with builtin bindings get the from [`glsl_built_in`](glsl_built_in).
+/// - Varyings with location bindings are named `_S_location_X` where `S` is a
+/// prefix identifying which pipeline stage the varying connects, and `X` is
+/// the location.
+struct VaryingName<'a> {
+ binding: &'a crate::Binding,
+ stage: ShaderStage,
+ output: bool,
+ targetting_webgl: bool,
+}
+impl fmt::Display for VaryingName<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match *self.binding {
+ crate::Binding::Location { location, .. } => {
+ let prefix = match (self.stage, self.output) {
+ (ShaderStage::Compute, _) => unreachable!(),
+ // pipeline to vertex
+ (ShaderStage::Vertex, false) => "p2vs",
+ // vertex to fragment
+ (ShaderStage::Vertex, true) | (ShaderStage::Fragment, false) => "vs2fs",
+ // fragment to pipeline
+ (ShaderStage::Fragment, true) => "fs2p",
+ };
+ write!(f, "_{}_location{}", prefix, location,)
+ }
+ crate::Binding::BuiltIn(built_in) => {
+ write!(
+ f,
+ "{}",
+ glsl_built_in(built_in, self.output, self.targetting_webgl)
+ )
+ }
+ }
+ }
+}
+
+impl ShaderStage {
+ const fn to_str(self) -> &'static str {
+ match self {
+ ShaderStage::Compute => "cs",
+ ShaderStage::Fragment => "fs",
+ ShaderStage::Vertex => "vs",
+ }
+ }
+}
+
+/// Shorthand result used internally by the backend
+type BackendResult<T = ()> = Result<T, Error>;
+
+/// A GLSL compilation error.
+#[derive(Debug, Error)]
+pub enum Error {
+ /// A error occurred while writing to the output.
+ #[error("Format error")]
+ FmtError(#[from] FmtError),
+ /// The specified [`Version`] doesn't have all required [`Features`].
+ ///
+ /// Contains the missing [`Features`].
+ #[error("The selected version doesn't support {0:?}")]
+ MissingFeatures(Features),
+ /// [`AddressSpace::PushConstant`](crate::AddressSpace::PushConstant) was used more than
+ /// once in the entry point, which isn't supported.
+ #[error("Multiple push constants aren't supported")]
+ MultiplePushConstants,
+ /// The specified [`Version`] isn't supported.
+ #[error("The specified version isn't supported")]
+ VersionNotSupported,
+ /// The entry point couldn't be found.
+ #[error("The requested entry point couldn't be found")]
+ EntryPointNotFound,
+ /// A call was made to an unsupported external.
+ #[error("A call was made to an unsupported external: {0}")]
+ UnsupportedExternal(String),
+ /// A scalar with an unsupported width was requested.
+ #[error("A scalar with an unsupported width was requested: {0:?} {1:?}")]
+ UnsupportedScalar(crate::ScalarKind, crate::Bytes),
+ /// A image was used with multiple samplers, which isn't supported.
+ #[error("A image was used with multiple samplers")]
+ ImageMultipleSamplers,
+ #[error("{0}")]
+ Custom(String),
+}
+
+/// Binary operation with a different logic on the GLSL side.
+enum BinaryOperation {
+ /// Vector comparison should use the function like `greaterThan()`, etc.
+ VectorCompare,
+ /// Vector component wise operation; used to polyfill unsupported ops like `|` and `&` for `bvecN`'s
+ VectorComponentWise,
+ /// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem`.
+ Modulo,
+ /// Any plain operation. No additional logic required.
+ Other,
+}
+
+/// Writer responsible for all code generation.
+pub struct Writer<'a, W> {
+ // Inputs
+ /// The module being written.
+ module: &'a crate::Module,
+ /// The module analysis.
+ info: &'a valid::ModuleInfo,
+ /// The output writer.
+ out: W,
+ /// User defined configuration to be used.
+ options: &'a Options,
+ /// The bound checking policies to be used
+ policies: proc::BoundsCheckPolicies,
+
+ // Internal State
+ /// Features manager used to store all the needed features and write them.
+ features: FeaturesManager,
+ namer: proc::Namer,
+ /// A map with all the names needed for writing the module
+ /// (generated by a [`Namer`](crate::proc::Namer)).
+ names: crate::FastHashMap<NameKey, String>,
+ /// A map with the names of global variables needed for reflections.
+ reflection_names_globals: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
+ /// The selected entry point.
+ entry_point: &'a crate::EntryPoint,
+ /// The index of the selected entry point.
+ entry_point_idx: proc::EntryPointIndex,
+ /// A generator for unique block numbers.
+ block_id: IdGenerator,
+ /// Set of expressions that have associated temporary variables.
+ named_expressions: crate::NamedExpressions,
+ /// Set of expressions that need to be baked to avoid unnecessary repetition in output
+ need_bake_expressions: back::NeedBakeExpressions,
+ /// How many views to render to, if doing multiview rendering.
+ multiview: Option<std::num::NonZeroU32>,
+}
+
+impl<'a, W: Write> Writer<'a, W> {
+ /// Creates a new [`Writer`] instance.
+ ///
+ /// # Errors
+ /// - If the version specified is invalid or supported.
+ /// - If the entry point couldn't be found in the module.
+ /// - If the version specified doesn't support some used features.
+ pub fn new(
+ out: W,
+ module: &'a crate::Module,
+ info: &'a valid::ModuleInfo,
+ options: &'a Options,
+ pipeline_options: &'a PipelineOptions,
+ policies: proc::BoundsCheckPolicies,
+ ) -> Result<Self, Error> {
+ // Check if the requested version is supported
+ if !options.version.is_supported() {
+ log::error!("Version {}", options.version);
+ return Err(Error::VersionNotSupported);
+ }
+
+ // Try to find the entry point and corresponding index
+ let ep_idx = module
+ .entry_points
+ .iter()
+ .position(|ep| {
+ pipeline_options.shader_stage == ep.stage && pipeline_options.entry_point == ep.name
+ })
+ .ok_or(Error::EntryPointNotFound)?;
+
+ // Generate a map with names required to write the module
+ let mut names = crate::FastHashMap::default();
+ let mut namer = proc::Namer::default();
+ namer.reset(module, keywords::RESERVED_KEYWORDS, &["gl_"], &mut names);
+
+ // Build the instance
+ let mut this = Self {
+ module,
+ info,
+ out,
+ options,
+ policies,
+
+ namer,
+ features: FeaturesManager::new(),
+ names,
+ reflection_names_globals: crate::FastHashMap::default(),
+ entry_point: &module.entry_points[ep_idx],
+ entry_point_idx: ep_idx as u16,
+ multiview: pipeline_options.multiview,
+ block_id: IdGenerator::default(),
+ named_expressions: Default::default(),
+ need_bake_expressions: Default::default(),
+ };
+
+ // Find all features required to print this module
+ this.collect_required_features()?;
+
+ Ok(this)
+ }
+
+ /// Writes the [`Module`](crate::Module) as glsl to the output
+ ///
+ /// # Notes
+ /// If an error occurs while writing, the output might have been written partially
+ ///
+ /// # Panics
+ /// Might panic if the module is invalid
+ pub fn write(&mut self) -> Result<ReflectionInfo, Error> {
+ // We use `writeln!(self.out)` throughout the write to add newlines
+ // to make the output more readable
+
+ let es = self.options.version.is_es();
+
+ // Write the version (It must be the first thing or it isn't a valid glsl output)
+ writeln!(self.out, "#version {}", self.options.version)?;
+ // Write all the needed extensions
+ //
+ // This used to be the last thing being written as it allowed to search for features while
+ // writing the module saving some loops but some older versions (420 or less) required the
+ // extensions to appear before being used, even though extensions are part of the
+ // preprocessor not the processor ¯\_(ツ)_/¯
+ self.features.write(self.options.version, &mut self.out)?;
+
+ // Write the additional extensions
+ if self
+ .options
+ .writer_flags
+ .contains(WriterFlags::TEXTURE_SHADOW_LOD)
+ {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt
+ writeln!(self.out, "#extension GL_EXT_texture_shadow_lod : require")?;
+ }
+
+ // glsl es requires a precision to be specified for floats and ints
+ // TODO: Should this be user configurable?
+ if es {
+ writeln!(self.out)?;
+ writeln!(self.out, "precision highp float;")?;
+ writeln!(self.out, "precision highp int;")?;
+ writeln!(self.out)?;
+ }
+
+ if self.entry_point.stage == ShaderStage::Compute {
+ let workgroup_size = self.entry_point.workgroup_size;
+ writeln!(
+ self.out,
+ "layout(local_size_x = {}, local_size_y = {}, local_size_z = {}) in;",
+ workgroup_size[0], workgroup_size[1], workgroup_size[2]
+ )?;
+ writeln!(self.out)?;
+ }
+
+ // Enable early depth tests if needed
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+ // If early depth test is supported for this version of GLSL
+ if self.options.version.supports_early_depth_test() {
+ writeln!(self.out, "layout(early_fragment_tests) in;")?;
+
+ if let Some(conservative) = depth_test.conservative {
+ use crate::ConservativeDepth as Cd;
+
+ let depth = match conservative {
+ Cd::GreaterEqual => "greater",
+ Cd::LessEqual => "less",
+ Cd::Unchanged => "unchanged",
+ };
+ writeln!(self.out, "layout (depth_{}) out float gl_FragDepth;", depth)?;
+ }
+ writeln!(self.out)?;
+ } else {
+ log::warn!(
+ "Early depth testing is not supported for this version of GLSL: {}",
+ self.options.version
+ );
+ }
+ }
+
+ if self.entry_point.stage == ShaderStage::Vertex && self.options.version.is_webgl() {
+ if let Some(multiview) = self.multiview.as_ref() {
+ writeln!(self.out, "layout(num_views = {}) in;", multiview)?;
+ writeln!(self.out)?;
+ }
+ }
+
+ let ep_info = self.info.get_entry_point(self.entry_point_idx as usize);
+
+ // Write struct types.
+ //
+ // This are always ordered because the IR is structured in a way that
+ // you can't make a struct without adding all of its members first.
+ for (handle, ty) in self.module.types.iter() {
+ if let TypeInner::Struct { ref members, .. } = ty.inner {
+ // Structures ending with runtime-sized arrays can only be
+ // rendered as shader storage blocks in GLSL, not stand-alone
+ // struct types.
+ if !self.module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&self.module.types)
+ {
+ let name = &self.names[&NameKey::Type(handle)];
+ write!(self.out, "struct {} ", name)?;
+ self.write_struct_body(handle, members)?;
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+
+ // Write the globals
+ //
+ // We filter all globals that aren't used by the selected entry point as they might be
+ // interfere with each other (i.e. two globals with the same location but different with
+ // different classes)
+ for (handle, global) in self.module.global_variables.iter() {
+ if ep_info[handle].is_empty() {
+ continue;
+ }
+
+ match self.module.types[global.ty].inner {
+ // We treat images separately because they might require
+ // writing the storage format
+ TypeInner::Image {
+ mut dim,
+ arrayed,
+ class,
+ } => {
+ // Gather the storage format if needed
+ let storage_format_access = match self.module.types[global.ty].inner {
+ TypeInner::Image {
+ class: crate::ImageClass::Storage { format, access },
+ ..
+ } => Some((format, access)),
+ _ => None,
+ };
+
+ if dim == crate::ImageDimension::D1 && es {
+ dim = crate::ImageDimension::D2
+ }
+
+ // Gether the location if needed
+ let layout_binding = if self.options.version.supports_explicit_locations() {
+ let br = global.binding.as_ref().unwrap();
+ self.options.binding_map.get(br).cloned()
+ } else {
+ None
+ };
+
+ // Write all the layout qualifiers
+ if layout_binding.is_some() || storage_format_access.is_some() {
+ write!(self.out, "layout(")?;
+ if let Some(binding) = layout_binding {
+ write!(self.out, "binding = {}", binding)?;
+ }
+ if let Some((format, _)) = storage_format_access {
+ let format_str = glsl_storage_format(format);
+ let separator = match layout_binding {
+ Some(_) => ",",
+ None => "",
+ };
+ write!(self.out, "{}{}", separator, format_str)?;
+ }
+ write!(self.out, ") ")?;
+ }
+
+ if let Some((_, access)) = storage_format_access {
+ self.write_storage_access(access)?;
+ }
+
+ // All images in glsl are `uniform`
+ // The trailing space is important
+ write!(self.out, "uniform ")?;
+
+ // write the type
+ //
+ // This is way we need the leading space because `write_image_type` doesn't add
+ // any spaces at the beginning or end
+ self.write_image_type(dim, arrayed, class)?;
+
+ // Finally write the name and end the global with a `;`
+ // The leading space is important
+ let global_name = self.get_global_name(handle, global);
+ writeln!(self.out, " {};", global_name)?;
+ writeln!(self.out)?;
+
+ self.reflection_names_globals.insert(handle, global_name);
+ }
+ // glsl has no concept of samplers so we just ignore it
+ TypeInner::Sampler { .. } => continue,
+ // All other globals are written by `write_global`
+ _ => {
+ if !ep_info[handle].is_empty() {
+ self.write_global(handle, global)?;
+ // Add a newline (only for readability)
+ writeln!(self.out)?;
+ }
+ }
+ }
+ }
+
+ for arg in self.entry_point.function.arguments.iter() {
+ self.write_varying(arg.binding.as_ref(), arg.ty, false)?;
+ }
+ if let Some(ref result) = self.entry_point.function.result {
+ self.write_varying(result.binding.as_ref(), result.ty, true)?;
+ }
+ writeln!(self.out)?;
+
+ // Write all regular functions
+ for (handle, function) in self.module.functions.iter() {
+ // Check that the function doesn't use globals that aren't supported
+ // by the current entry point
+ if !ep_info.dominates_global_use(&self.info[handle]) {
+ continue;
+ }
+
+ let fun_info = &self.info[handle];
+
+ // Write the function
+ self.write_function(back::FunctionType::Function(handle), function, fun_info)?;
+
+ writeln!(self.out)?;
+ }
+
+ self.write_function(
+ back::FunctionType::EntryPoint(self.entry_point_idx),
+ &self.entry_point.function,
+ ep_info,
+ )?;
+
+ // Add newline at the end of file
+ writeln!(self.out)?;
+
+ // Collect all reflection info and return it to the user
+ self.collect_reflection_info()
+ }
+
+ fn write_array_size(
+ &mut self,
+ base: Handle<crate::Type>,
+ size: crate::ArraySize,
+ ) -> BackendResult {
+ write!(self.out, "[")?;
+
+ // Write the array size
+ // Writes nothing if `ArraySize::Dynamic`
+ // Panics if `ArraySize::Constant` has a constant that isn't an sint or uint
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ match self.module.constants[const_handle].inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Uint(size),
+ } => write!(self.out, "{}", size)?,
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Sint(size),
+ } => write!(self.out, "{}", size)?,
+ _ => unreachable!(),
+ }
+ }
+ crate::ArraySize::Dynamic => (),
+ }
+
+ write!(self.out, "]")?;
+
+ if let TypeInner::Array {
+ base: next_base,
+ size: next_size,
+ ..
+ } = self.module.types[base].inner
+ {
+ self.write_array_size(next_base, next_size)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ ///
+ /// # Panics
+ /// - If type is either a image, a sampler, a pointer, or a struct
+ /// - If it's an Array with a [`ArraySize::Constant`](crate::ArraySize::Constant) with a
+ /// constant that isn't a [`Scalar`](crate::ConstantInner::Scalar) or if the
+ /// scalar value isn't an [`Sint`](crate::ScalarValue::Sint) or [`Uint`](crate::ScalarValue::Uint)
+ fn write_value_type(&mut self, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ // Scalars are simple we just get the full name from `glsl_scalar`
+ TypeInner::Scalar { kind, width }
+ | TypeInner::Atomic { kind, width }
+ | TypeInner::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space: _,
+ } => write!(self.out, "{}", glsl_scalar(kind, width)?.full)?,
+ // Vectors are just `gvecN` where `g` is the scalar prefix and `N` is the vector size
+ TypeInner::Vector { size, kind, width }
+ | TypeInner::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space: _,
+ } => write!(
+ self.out,
+ "{}vec{}",
+ glsl_scalar(kind, width)?.prefix,
+ size as u8
+ )?,
+ // Matrices are written with `gmatMxN` where `g` is the scalar prefix (only floats and
+ // doubles are allowed), `M` is the columns count and `N` is the rows count
+ //
+ // glsl supports a matrix shorthand `gmatN` where `N` = `M` but it doesn't justify the
+ // extra branch to write matrices this way
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => write!(
+ self.out,
+ "{}mat{}x{}",
+ glsl_scalar(crate::ScalarKind::Float, width)?.prefix,
+ columns as u8,
+ rows as u8
+ )?,
+ // GLSL arrays are written as `type name[size]`
+ // Current code is written arrays only as `[size]`
+ // Base `type` and `name` should be written outside
+ TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?,
+ // Panic if either Image, Sampler, Pointer, or a Struct is being written
+ //
+ // Write all variants instead of `_` so that if new variants are added a
+ // no exhaustiveness error is thrown
+ TypeInner::Pointer { .. }
+ | TypeInner::Struct { .. }
+ | TypeInner::Image { .. }
+ | TypeInner::Sampler { .. }
+ | TypeInner::BindingArray { .. } => {
+ return Err(Error::Custom(format!("Unable to write type {:?}", inner)))
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ ///
+ /// # Panics
+ /// - If type is either a image or sampler
+ /// - If it's an Array with a [`ArraySize::Constant`](crate::ArraySize::Constant) with a
+ /// constant that isn't a [`Scalar`](crate::ConstantInner::Scalar) or if the
+ /// scalar value isn't an [`Sint`](crate::ScalarValue::Sint) or [`Uint`](crate::ScalarValue::Uint)
+ fn write_type(&mut self, ty: Handle<crate::Type>) -> BackendResult {
+ match self.module.types[ty].inner {
+ // glsl has no pointer types so just write types as normal and loads are skipped
+ TypeInner::Pointer { base, .. } => self.write_type(base),
+ // glsl structs are written as just the struct name
+ TypeInner::Struct { .. } => {
+ // Get the struct name
+ let name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{}", name)?;
+ Ok(())
+ }
+ // glsl array has the size separated from the base type
+ TypeInner::Array { base, .. } => self.write_type(base),
+ ref other => self.write_value_type(other),
+ }
+ }
+
+ /// Helper method to write a image type
+ ///
+ /// # Notes
+ /// Adds no leading or trailing whitespace
+ fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ ) -> BackendResult {
+ // glsl images consist of four parts the scalar prefix, the image "type", the dimensions
+ // and modifiers
+ //
+ // There exists two image types
+ // - sampler - for sampled images
+ // - image - for storage images
+ //
+ // There are three possible modifiers that can be used together and must be written in
+ // this order to be valid
+ // - MS - used if it's a multisampled image
+ // - Array - used if it's an image array
+ // - Shadow - used if it's a depth image
+ use crate::ImageClass as Ic;
+
+ let (base, kind, ms, comparison) = match class {
+ Ic::Sampled { kind, multi: true } => ("sampler", kind, "MS", ""),
+ Ic::Sampled { kind, multi: false } => ("sampler", kind, "", ""),
+ Ic::Depth { multi: true } => ("sampler", crate::ScalarKind::Float, "MS", ""),
+ Ic::Depth { multi: false } => ("sampler", crate::ScalarKind::Float, "", "Shadow"),
+ Ic::Storage { format, .. } => ("image", format.into(), "", ""),
+ };
+
+ write!(
+ self.out,
+ "highp {}{}{}{}{}{}",
+ glsl_scalar(kind, 4)?.prefix,
+ base,
+ glsl_dimension(dim),
+ ms,
+ if arrayed { "Array" } else { "" },
+ comparison
+ )?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write non images/sampler globals
+ ///
+ /// # Notes
+ /// Adds a newline
+ ///
+ /// # Panics
+ /// If the global has type sampler
+ fn write_global(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ if self.options.version.supports_explicit_locations() {
+ if let Some(ref br) = global.binding {
+ match self.options.binding_map.get(br) {
+ Some(binding) => {
+ let layout = match global.space {
+ crate::AddressSpace::Storage { .. } => {
+ if self.options.version.supports_std430_layout() {
+ "std430, "
+ } else {
+ "std140, "
+ }
+ }
+ crate::AddressSpace::Uniform => "std140, ",
+ _ => "",
+ };
+ write!(self.out, "layout({}binding = {}) ", layout, binding)?
+ }
+ None => {
+ log::debug!("unassigned binding for {:?}", global.name);
+ if let crate::AddressSpace::Storage { .. } = global.space {
+ if self.options.version.supports_std430_layout() {
+ write!(self.out, "layout(std430) ")?
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if let crate::AddressSpace::Storage { access } = global.space {
+ self.write_storage_access(access)?;
+ }
+
+ if let Some(storage_qualifier) = glsl_storage_qualifier(global.space) {
+ write!(self.out, "{} ", storage_qualifier)?;
+ }
+
+ match global.space {
+ crate::AddressSpace::Private => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::WorkGroup => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::PushConstant => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::Uniform => {
+ self.write_interface_block(handle, global)?;
+ }
+ crate::AddressSpace::Storage { .. } => {
+ self.write_interface_block(handle, global)?;
+ }
+ // A global variable in the `Function` address space is a
+ // contradiction in terms.
+ crate::AddressSpace::Function => unreachable!(),
+ // Textures and samplers are handled directly in `Writer::write`.
+ crate::AddressSpace::Handle => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ fn write_simple_global(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ self.write_type(global.ty)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+
+ if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+
+ if global.space.initializable() && is_value_init_supported(self.module, global.ty) {
+ write!(self.out, " = ")?;
+ if let Some(init) = global.init {
+ self.write_constant(init)?;
+ } else {
+ self.write_zero_init_value(global.ty)?;
+ }
+ }
+
+ writeln!(self.out, ";")?;
+
+ if let crate::AddressSpace::PushConstant = global.space {
+ let global_name = self.get_global_name(handle, global);
+ self.reflection_names_globals.insert(handle, global_name);
+ }
+
+ Ok(())
+ }
+
+ /// Write an interface block for a single Naga global.
+ ///
+ /// Write `block_name { members }`. Since `block_name` must be unique
+ /// between blocks and structs, we add `_block_ID` where `ID` is a
+ /// `IdGenerator` generated number. Write `members` in the same way we write
+ /// a struct's members.
+ fn write_interface_block(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ // Write the block name, it's just the struct name appended with `_block_ID`
+ let ty_name = &self.names[&NameKey::Type(global.ty)];
+ let block_name = format!(
+ "{}_block_{}{:?}",
+ ty_name,
+ self.block_id.generate(),
+ self.entry_point.stage,
+ );
+ write!(self.out, "{} ", block_name)?;
+ self.reflection_names_globals.insert(handle, block_name);
+
+ match self.module.types[global.ty].inner {
+ crate::TypeInner::Struct { ref members, .. }
+ if self.module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&self.module.types) =>
+ {
+ // Structs with dynamically sized arrays must have their
+ // members lifted up as members of the interface block. GLSL
+ // can't write such struct types anyway.
+ self.write_struct_body(global.ty, members)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+ }
+ _ => {
+ // A global of any other type is written as the sole member
+ // of the interface block. Since the interface block is
+ // anonymous, this becomes visible in the global scope.
+ write!(self.out, "{{ ")?;
+ self.write_type(global.ty)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, "; }}")?;
+ }
+ }
+
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// Clears `need_bake_expressions` set before adding to it
+ fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+ for expr in func.expressions.iter() {
+ let expr_info = &info[expr.0];
+ let min_ref_count = func.expressions[expr.0].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(expr.0);
+ }
+ // if the expression is a Dot product with integer arguments,
+ // then the args needs baking as well
+ if let (
+ fun_handle,
+ &Expression::Math {
+ fun: crate::MathFunction::Dot,
+ arg,
+ arg1,
+ ..
+ },
+ ) = expr
+ {
+ let inner = info[fun_handle].ty.inner_with(&self.module.types);
+ if let TypeInner::Scalar { kind, .. } = *inner {
+ match kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+ }
+
+ /// Helper method used to get a name for a global
+ ///
+ /// Globals have different naming schemes depending on their binding:
+ /// - Globals without bindings use the name from the [`Namer`](crate::proc::Namer)
+ /// - Globals with resource binding are named `_group_X_binding_Y` where `X`
+ /// is the group and `Y` is the binding
+ fn get_global_name(
+ &self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> String {
+ match global.binding {
+ Some(ref br) => {
+ format!(
+ "_group_{}_binding_{}_{}",
+ br.group,
+ br.binding,
+ self.entry_point.stage.to_str()
+ )
+ }
+ None => self.names[&NameKey::GlobalVariable(handle)].clone(),
+ }
+ }
+
+ /// Helper method used to write a name for a global without additional heap allocation
+ fn write_global_name(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ match global.binding {
+ Some(ref br) => write!(
+ self.out,
+ "_group_{}_binding_{}_{}",
+ br.group,
+ br.binding,
+ self.entry_point.stage.to_str()
+ )?,
+ None => write!(
+ self.out,
+ "{}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?,
+ }
+
+ Ok(())
+ }
+
+ /// Write a GLSL global that will carry a Naga entry point's argument or return value.
+ ///
+ /// A Naga entry point's arguments and return value are rendered in GLSL as
+ /// variables at global scope with the `in` and `out` storage qualifiers.
+ /// The code we generate for `main` loads from all the `in` globals into
+ /// appropriately named locals. Before it returns, `main` assigns the
+ /// components of its return value into all the `out` globals.
+ ///
+ /// This function writes a declaration for one such GLSL global,
+ /// representing a value passed into or returned from [`self.entry_point`]
+ /// that has a [`Location`] binding. The global's name is generated based on
+ /// the location index and the shader stages being connected; see
+ /// [`VaryingName`]. This means we don't need to know the names of
+ /// arguments, just their types and bindings.
+ ///
+ /// Emit nothing for entry point arguments or return values with [`BuiltIn`]
+ /// bindings; `main` will read from or assign to the appropriate GLSL
+ /// special variable; these are pre-declared. As an exception, we do declare
+ /// `gl_Position` or `gl_FragCoord` with the `invariant` qualifier if
+ /// needed.
+ ///
+ /// Use `output` together with [`self.entry_point.stage`] to determine which
+ /// shader stages are being connected, and choose the `in` or `out` storage
+ /// qualifier.
+ ///
+ /// [`self.entry_point`]: Writer::entry_point
+ /// [`self.entry_point.stage`]: crate::EntryPoint::stage
+ /// [`Location`]: crate::Binding::Location
+ /// [`BuiltIn`]: crate::Binding::BuiltIn
+ fn write_varying(
+ &mut self,
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ output: bool,
+ ) -> Result<(), Error> {
+ // For a struct, emit a separate global for each member with a binding.
+ if let crate::TypeInner::Struct { ref members, .. } = self.module.types[ty].inner {
+ for member in members {
+ self.write_varying(member.binding.as_ref(), member.ty, output)?;
+ }
+ return Ok(());
+ }
+
+ let binding = match binding {
+ None => return Ok(()),
+ Some(binding) => binding,
+ };
+
+ let (location, interpolation, sampling) = match *binding {
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => (location, interpolation, sampling),
+ crate::Binding::BuiltIn(built_in) => {
+ if let crate::BuiltIn::Position { invariant: true } = built_in {
+ writeln!(
+ self.out,
+ "invariant {};",
+ glsl_built_in(built_in, output, self.options.version.is_webgl())
+ )?;
+ }
+ return Ok(());
+ }
+ };
+
+ // Write the interpolation modifier if needed
+ //
+ // We ignore all interpolation and auxiliary modifiers that aren't used in fragment
+ // shaders' input globals or vertex shaders' output globals.
+ let emit_interpolation_and_auxiliary = match self.entry_point.stage {
+ ShaderStage::Vertex => output,
+ ShaderStage::Fragment => !output,
+ _ => false,
+ };
+
+ // Write the I/O locations, if allowed
+ if self.options.version.supports_explicit_locations() || !emit_interpolation_and_auxiliary {
+ write!(self.out, "layout(location = {}) ", location)?;
+ }
+
+ // Write the interpolation qualifier.
+ if let Some(interp) = interpolation {
+ if emit_interpolation_and_auxiliary {
+ write!(self.out, "{} ", glsl_interpolation(interp))?;
+ }
+ }
+
+ // Write the sampling auxiliary qualifier.
+ //
+ // Before GLSL 4.2, the `centroid` and `sample` qualifiers were required to appear
+ // immediately before the `in` / `out` qualifier, so we'll just follow that rule
+ // here, regardless of the version.
+ if let Some(sampling) = sampling {
+ if emit_interpolation_and_auxiliary {
+ if let Some(qualifier) = glsl_sampling(sampling) {
+ write!(self.out, "{} ", qualifier)?;
+ }
+ }
+ }
+
+ // Write the input/output qualifier.
+ write!(self.out, "{} ", if output { "out" } else { "in" })?;
+
+ // Write the type
+ // `write_type` adds no leading or trailing spaces
+ self.write_type(ty)?;
+
+ // Finally write the global name and end the global with a `;` and a newline
+ // Leading space is important
+ let vname = VaryingName {
+ binding: &crate::Binding::Location {
+ location,
+ interpolation: None,
+ sampling: None,
+ },
+ stage: self.entry_point.stage,
+ output,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ writeln!(self.out, " {};", vname)?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write functions (both entry points and regular functions)
+ ///
+ /// # Notes
+ /// Adds a newline
+ fn write_function(
+ &mut self,
+ ty: back::FunctionType,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ ) -> BackendResult {
+ // Create a function context for the function being written
+ let ctx = back::FunctionCtx {
+ ty,
+ info,
+ expressions: &func.expressions,
+ named_expressions: &func.named_expressions,
+ };
+
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(func, info);
+
+ // Write the function header
+ //
+ // glsl headers are the same as in c:
+ // `ret_type name(args)`
+ // `ret_type` is the return type
+ // `name` is the function name
+ // `args` is a comma separated list of `type name`
+ // | - `type` is the argument type
+ // | - `name` is the argument name
+
+ // Start by writing the return type if any otherwise write void
+ // This is the only place where `void` is a valid type
+ // (though it's more a keyword than a type)
+ if let back::FunctionType::EntryPoint(_) = ctx.ty {
+ write!(self.out, "void")?;
+ } else if let Some(ref result) = func.result {
+ self.write_type(result.ty)?;
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ // Write the function name and open parentheses for the argument list
+ let function_name = match ctx.ty {
+ back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
+ back::FunctionType::EntryPoint(_) => "main",
+ };
+ write!(self.out, " {}(", function_name)?;
+
+ // Write the comma separated argument list
+ //
+ // We need access to `Self` here so we use the reference passed to the closure as an
+ // argument instead of capturing as that would cause a borrow checker error
+ let arguments = match ctx.ty {
+ back::FunctionType::EntryPoint(_) => &[][..],
+ back::FunctionType::Function(_) => &func.arguments,
+ };
+ let arguments: Vec<_> = arguments
+ .iter()
+ .enumerate()
+ .filter(|&(_, arg)| match self.module.types[arg.ty].inner {
+ TypeInner::Sampler { .. } => false,
+ _ => true,
+ })
+ .collect();
+ self.write_slice(&arguments, |this, _, &(i, arg)| {
+ // Write the argument type
+ match this.module.types[arg.ty].inner {
+ // We treat images separately because they might require
+ // writing the storage format
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ // Write the storage format if needed
+ if let TypeInner::Image {
+ class: crate::ImageClass::Storage { format, .. },
+ ..
+ } = this.module.types[arg.ty].inner
+ {
+ write!(this.out, "layout({}) ", glsl_storage_format(format))?;
+ }
+
+ // write the type
+ //
+ // This is way we need the leading space because `write_image_type` doesn't add
+ // any spaces at the beginning or end
+ this.write_image_type(dim, arrayed, class)?;
+ }
+ TypeInner::Pointer { base, .. } => {
+ // write parameter qualifiers
+ write!(this.out, "inout ")?;
+ this.write_type(base)?;
+ }
+ // All other types are written by `write_type`
+ _ => {
+ this.write_type(arg.ty)?;
+ }
+ }
+
+ // Write the argument name
+ // The leading space is important
+ write!(this.out, " {}", &this.names[&ctx.argument_key(i as u32)])?;
+
+ // Write array size
+ if let TypeInner::Array { base, size, .. } = this.module.types[arg.ty].inner {
+ this.write_array_size(base, size)?;
+ }
+
+ Ok(())
+ })?;
+
+ // Close the parentheses and open braces to start the function body
+ writeln!(self.out, ") {{")?;
+
+ // Compose the function arguments from globals, in case of an entry point.
+ if let back::FunctionType::EntryPoint(ep_index) = ctx.ty {
+ let stage = self.module.entry_points[ep_index as usize].stage;
+ for (index, arg) in func.arguments.iter().enumerate() {
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(arg.ty)?;
+ let name = &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
+ write!(self.out, " {}", name)?;
+ write!(self.out, " = ")?;
+ match self.module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ self.write_type(arg.ty)?;
+ write!(self.out, "(")?;
+ for (index, member) in members.iter().enumerate() {
+ let varying_name = VaryingName {
+ binding: member.binding.as_ref().unwrap(),
+ stage,
+ output: false,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{}", varying_name)?;
+ }
+ writeln!(self.out, ");")?;
+ }
+ _ => {
+ let varying_name = VaryingName {
+ binding: arg.binding.as_ref().unwrap(),
+ stage,
+ output: false,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ writeln!(self.out, "{};", varying_name)?;
+ }
+ }
+ }
+ }
+
+ // Write all function locals
+ // Locals are `type name (= init)?;` where the init part (including the =) are optional
+ //
+ // Always adds a newline
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability) and the type
+ // `write_type` adds no trailing space
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(local.ty)?;
+
+ // Write the local name
+ // The leading space is important
+ write!(self.out, " {}", self.names[&ctx.name_key(handle)])?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = self.module.types[local.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ // Put the equal signal only if there's a initializer
+ // The leading and trailing spaces aren't needed but help with readability
+ write!(self.out, " = ")?;
+
+ // Write the constant
+ // `write_constant` adds no trailing or leading space/newline
+ self.write_constant(init)?;
+ } else if is_value_init_supported(self.module, local.ty) {
+ write!(self.out, " = ")?;
+ self.write_zero_init_value(local.ty)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // Write a statement, the indentation should always be 1 when writing the function body
+ // `write_stmt` adds a newline
+ self.write_stmt(sta, &ctx, back::Level(1))?;
+ }
+
+ // Close braces and add a newline
+ writeln!(self.out, "}}")?;
+
+ Ok(())
+ }
+
+ /// Helper method that writes a list of comma separated `T` with a writer function `F`
+ ///
+ /// The writer function `F` receives a mutable reference to `self` that if needed won't cause
+ /// borrow checker issues (using for example a closure with `self` will cause issues), the
+ /// second argument is the 0 based index of the element on the list, and the last element is
+ /// a reference to the element `T` being written
+ ///
+ /// # Notes
+ /// - Adds no newlines or leading/trailing whitespace
+ /// - The last element won't have a trailing `,`
+ fn write_slice<T, F: FnMut(&mut Self, u32, &T) -> BackendResult>(
+ &mut self,
+ data: &[T],
+ mut f: F,
+ ) -> BackendResult {
+ // Loop trough `data` invoking `f` for each element
+ for (i, item) in data.iter().enumerate() {
+ f(self, i as u32, item)?;
+
+ // Only write a comma if isn't the last element
+ if i != data.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write constants
+ ///
+ /// # Notes
+ /// Adds no newlines or leading/trailing whitespace
+ fn write_constant(&mut self, handle: Handle<crate::Constant>) -> BackendResult {
+ use crate::ScalarValue as Sv;
+
+ match self.module.constants[handle].inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => match *value {
+ // Signed integers don't need anything special
+ Sv::Sint(int) => write!(self.out, "{}", int)?,
+ // Unsigned integers need a `u` at the end
+ //
+ // While `core` doesn't necessarily need it, it's allowed and since `es` needs it we
+ // always write it as the extra branch wouldn't have any benefit in readability
+ Sv::Uint(int) => write!(self.out, "{}u", int)?,
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero which is needed for a valid glsl float constant
+ Sv::Float(float) => write!(self.out, "{:?}", float)?,
+ // Booleans are either `true` or `false` so nothing special needs to be done
+ Sv::Bool(boolean) => write!(self.out, "{}", boolean)?,
+ },
+ // Composite constant are created using the same syntax as compose
+ // `type(components)` where `components` is a comma separated list of constants
+ crate::ConstantInner::Composite { ty, ref components } => {
+ self.write_type(ty)?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, "(")?;
+
+ // Write the comma separated constants
+ self.write_slice(components, |this, _, arg| this.write_constant(*arg))?;
+
+ write!(self.out, ")")?
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to output a dot product as an arithmetic expression
+ ///
+ fn write_dot_product(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ arg1: Handle<crate::Expression>,
+ size: usize,
+ ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ // Write parantheses around the dot product expression to prevent operators
+ // with different precedences from applying earlier.
+ write!(self.out, "(")?;
+
+ // Cycle trough all the components of the vector
+ for index in 0..size {
+ let component = back::COMPONENTS[index];
+ // Write the addition to the previous product
+ // This will print an extra '+' at the beginning but that is fine in glsl
+ write!(self.out, " + ")?;
+ // Write the first vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.write_expr(arg, ctx)?;
+ // Access the current component on the first vector
+ write!(self.out, ".{} * ", component)?;
+ // Write the second vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.write_expr(arg1, ctx)?;
+ // Access the current component on the second vector
+ write!(self.out, ".{}", component)?;
+ }
+
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct_body(
+ &mut self,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ ) -> BackendResult {
+ // glsl structs are written as in C
+ // `struct name() { members };`
+ // | `struct` is a keyword
+ // | `name` is the struct name
+ // | `members` is a semicolon separated list of `type name`
+ // | `type` is the member type
+ // | `name` is the member name
+ writeln!(self.out, "{{")?;
+
+ for (idx, member) in members.iter().enumerate() {
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+
+ match self.module.types[member.ty].inner {
+ TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ self.write_type(base)?;
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, idx as u32)]
+ )?;
+ // Write [size]
+ self.write_array_size(base, size)?;
+ // Newline is important
+ writeln!(self.out, ";")?;
+ }
+ _ => {
+ // Write the member type
+ // Adds no trailing space
+ self.write_type(member.ty)?;
+
+ // Write the member name and put a semicolon
+ // The leading space is important
+ // All members must have a semicolon even the last one
+ writeln!(
+ self.out,
+ " {};",
+ &self.names[&NameKey::StructMember(handle, idx as u32)]
+ )?;
+ }
+ }
+ }
+
+ write!(self.out, "}}")?;
+ Ok(())
+ }
+
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ sta: &crate::Statement,
+ ctx: &back::FunctionCtx,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::Statement;
+
+ match *sta {
+ // This is where we can generate intermediate constants for some expression types.
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let info = &ctx.info[handle];
+ let ptr_class = info.ty.inner_with(&self.module.types).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ // GLSL can't save a pointer-valued expression in a variable,
+ // but we shouldn't ever need to: they should never be named expressions,
+ // and none of the expression types flagged by bake_ref_count can be pointer-valued.
+ None
+ } else if let Some(name) = ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else if self.need_bake_expressions.contains(&handle) {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else if info.ref_count == 0 {
+ Some(self.namer.call(""))
+ } else {
+ None
+ };
+
+ // If we are going to write an `ImageLoad` next and the target image
+ // is sampled and we are using the `Restrict` policy for bounds
+ // checking images we need to write a local holding the clamped lod.
+ if let crate::Expression::ImageLoad {
+ image,
+ level: Some(level_expr),
+ ..
+ } = ctx.expressions[handle]
+ {
+ if let TypeInner::Image {
+ class: crate::ImageClass::Sampled { .. },
+ ..
+ } = *ctx.info[image].ty.inner_with(&self.module.types)
+ {
+ if let proc::BoundsCheckPolicy::Restrict = self.policies.image {
+ write!(self.out, "{}", level)?;
+ self.write_clamped_lod(ctx, handle, image, level_expr)?
+ }
+ }
+ }
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{}", level)?;
+ self.write_named_expr(handle, name, ctx)?;
+ }
+ }
+ }
+ // Blocks are simple we just need to write the block statements between braces
+ // We could also just print the statements but this is more readable and maps more
+ // closely to the IR
+ Statement::Block(ref block) => {
+ write!(self.out, "{}", level)?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?
+ }
+ writeln!(self.out, "{}}}", level)?
+ }
+ // Ifs are written as in C:
+ // ```
+ // if(condition) {
+ // accept
+ // } else {
+ // reject
+ // }
+ // ```
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{}", level)?;
+ write!(self.out, "if (")?;
+ self.write_expr(condition, ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{}}} else {{", level)?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+ }
+
+ writeln!(self.out, "{}}}", level)?
+ }
+ // Switch are written as in C:
+ // ```
+ // switch (selector) {
+ // // Fallthrough
+ // case label:
+ // block
+ // // Non fallthrough
+ // case label:
+ // block
+ // break;
+ // default:
+ // block
+ // }
+ // ```
+ // Where the `default` case happens isn't important but we put it last
+ // so that we don't need to print a `break` for it
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{}", level)?;
+ write!(self.out, "switch(")?;
+ self.write_expr(selector, ctx)?;
+ writeln!(self.out, ") {{")?;
+ let type_postfix = match *ctx.info[selector].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ } => "u",
+ _ => "",
+ };
+
+ // Write all cases
+ let l2 = level.next();
+ for case in cases {
+ match case.value {
+ crate::SwitchValue::Integer(value) => {
+ writeln!(self.out, "{}case {}{}:", l2, value, type_postfix)?
+ }
+ crate::SwitchValue::Default => writeln!(self.out, "{}default:", l2)?,
+ }
+
+ for sta in case.body.iter() {
+ self.write_stmt(sta, ctx, l2.next())?;
+ }
+
+ // Write fallthrough comment if the case is fallthrough,
+ // otherwise write a break, if the case is not already
+ // broken out of at the end of its body.
+ if case.fall_through {
+ writeln!(self.out, "{}/* fallthrough */", l2.next())?;
+ } else if case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{}break;", l2.next())?;
+ }
+ }
+
+ writeln!(self.out, "{}}}", level)?
+ }
+ // Loops in naga IR are based on wgsl loops, glsl can emulate the behaviour by using a
+ // while true loop and appending the continuing block to the body resulting on:
+ // ```
+ // bool loop_init = true;
+ // while(true) {
+ // if (!loop_init) { <continuing> }
+ // loop_init = false;
+ // <body>
+ // }
+ // ```
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{}bool {} = true;", level, gate_name)?;
+ writeln!(self.out, "{}while(true) {{", level)?;
+ let l2 = level.next();
+ let l3 = l2.next();
+ writeln!(self.out, "{}if (!{}) {{", l2, gate_name)?;
+ for sta in continuing {
+ self.write_stmt(sta, ctx, l3)?;
+ }
+ if let Some(condition) = break_if {
+ write!(self.out, "{}if (", l3)?;
+ self.write_expr(condition, ctx)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", l3.next())?;
+ writeln!(self.out, "{}}}", l3)?;
+ }
+ writeln!(self.out, "{}}}", l2)?;
+ writeln!(self.out, "{}{} = false;", level.next(), gate_name)?;
+ } else {
+ writeln!(self.out, "{}while(true) {{", level)?;
+ }
+ for sta in body {
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+ writeln!(self.out, "{}}}", level)?
+ }
+ // Break, continue and return as written as in C
+ // `break;`
+ Statement::Break => {
+ write!(self.out, "{}", level)?;
+ writeln!(self.out, "break;")?
+ }
+ // `continue;`
+ Statement::Continue => {
+ write!(self.out, "{}", level)?;
+ writeln!(self.out, "continue;")?
+ }
+ // `return expr;`, `expr` is optional
+ Statement::Return { value } => {
+ write!(self.out, "{}", level)?;
+ match ctx.ty {
+ back::FunctionType::Function(_) => {
+ write!(self.out, "return")?;
+ // Write the expression to be returned if needed
+ if let Some(expr) = value {
+ write!(self.out, " ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ let ep = &self.module.entry_points[ep_index as usize];
+ if let Some(ref result) = ep.function.result {
+ let value = value.unwrap();
+ match self.module.types[result.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let temp_struct_name = match ctx.expressions[value] {
+ crate::Expression::Compose { .. } => {
+ let return_struct = "_tmp_return";
+ write!(
+ self.out,
+ "{} {} = ",
+ &self.names[&NameKey::Type(result.ty)],
+ return_struct
+ )?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{}", level)?;
+ Some(return_struct)
+ }
+ _ => None,
+ };
+
+ for (index, member) in members.iter().enumerate() {
+ // TODO: handle builtin in better way
+ if let Some(crate::Binding::BuiltIn(builtin)) =
+ member.binding
+ {
+ match builtin {
+ crate::BuiltIn::ClipDistance
+ | crate::BuiltIn::CullDistance
+ | crate::BuiltIn::PointSize => {
+ if self.options.version.is_es() {
+ continue;
+ }
+ }
+ _ => {}
+ }
+ }
+
+ let varying_name = VaryingName {
+ binding: member.binding.as_ref().unwrap(),
+ stage: ep.stage,
+ output: true,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ write!(self.out, "{} = ", varying_name)?;
+
+ if let Some(struct_name) = temp_struct_name {
+ write!(self.out, "{}", struct_name)?;
+ } else {
+ self.write_expr(value, ctx)?;
+ }
+
+ // Write field name
+ writeln!(
+ self.out,
+ ".{};",
+ &self.names
+ [&NameKey::StructMember(result.ty, index as u32)]
+ )?;
+ write!(self.out, "{}", level)?;
+ }
+ }
+ _ => {
+ let name = VaryingName {
+ binding: result.binding.as_ref().unwrap(),
+ stage: ep.stage,
+ output: true,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ write!(self.out, "{} = ", name)?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{}", level)?;
+ }
+ }
+ }
+
+ if let back::FunctionType::EntryPoint(ep_index) = ctx.ty {
+ if self.module.entry_points[ep_index as usize].stage
+ == crate::ShaderStage::Vertex
+ && self
+ .options
+ .writer_flags
+ .contains(WriterFlags::ADJUST_COORDINATE_SPACE)
+ {
+ writeln!(
+ self.out,
+ "gl_Position.yz = vec2(-gl_Position.y, gl_Position.z * 2.0 - gl_Position.w);",
+ )?;
+ write!(self.out, "{}", level)?;
+ }
+ }
+ writeln!(self.out, "return;")?;
+ }
+ }
+ }
+ // This is one of the places were glsl adds to the syntax of C in this case the discard
+ // keyword which ceases all further processing in a fragment shader, it's called OpKill
+ // in spir-v that's why it's called `Statement::Kill`
+ Statement::Kill => writeln!(self.out, "{}discard;", level)?,
+ // Issue a memory barrier. Please note that to ensure visibility,
+ // OpenGL always requires a call to the `barrier()` function after a `memoryBarrier*()`
+ Statement::Barrier(flags) => {
+ if flags.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{}memoryBarrierBuffer();", level)?;
+ }
+
+ if flags.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{}memoryBarrierShared();", level)?;
+ }
+
+ writeln!(self.out, "{}barrier();", level)?;
+ }
+ // Stores in glsl are just variable assignments written as `pointer = value;`
+ Statement::Store { pointer, value } => {
+ write!(self.out, "{}", level)?;
+ self.write_expr(pointer, ctx)?;
+ write!(self.out, " = ")?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?
+ }
+ // Stores a value into an image.
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{}", level)?;
+ self.write_image_store(ctx, image, coordinate, array_index, value)?
+ }
+ // A `Call` is written `name(arguments)` where `arguments` is a comma separated expressions list
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{}", level)?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ let result = self.module.functions[function].result.as_ref().unwrap();
+ self.write_type(result.ty)?;
+ write!(self.out, " {} = ", name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ write!(self.out, "{}(", &self.names[&NameKey::Function(function)])?;
+ let arguments: Vec<_> = arguments
+ .iter()
+ .enumerate()
+ .filter_map(|(i, arg)| {
+ let arg_ty = self.module.functions[function].arguments[i].ty;
+ match self.module.types[arg_ty].inner {
+ TypeInner::Sampler { .. } => None,
+ _ => Some(*arg),
+ }
+ })
+ .collect();
+ self.write_slice(&arguments, |this, _, arg| this.write_expr(*arg, ctx))?;
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{}", level)?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {} = ", res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ let fun_str = fun.to_glsl();
+ write!(self.out, "atomic{}(", fun_str)?;
+ self.write_expr(pointer, ctx)?;
+ write!(self.out, ", ")?;
+ // handle the special cases
+ match *fun {
+ crate::AtomicFunction::Subtract => {
+ // we just wrote `InterlockedAdd`, so negate the argument
+ write!(self.out, "-")?;
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::Custom(
+ "atomic CompareExchange is not implemented".to_string(),
+ ));
+ }
+ _ => {}
+ }
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write expressions
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ fn write_expr(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{}", name)?;
+ return Ok(());
+ }
+
+ match ctx.expressions[expr] {
+ // `Access` is applied to arrays, vectors and matrices and is written as indexing
+ Expression::Access { base, index } => {
+ self.write_expr(base, ctx)?;
+ write!(self.out, "[")?;
+ self.write_expr(index, ctx)?;
+ write!(self.out, "]")?
+ }
+ // `AccessIndex` is the same as `Access` except that the index is a constant and it can
+ // be applied to structs, in this case we need to find the name of the field at that
+ // index and write `base.field_name`
+ Expression::AccessIndex { base, index } => {
+ self.write_expr(base, ctx)?;
+
+ let base_ty_res = &ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&self.module.types);
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, space: _ } => {
+ resolved = &self.module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ match *resolved {
+ TypeInner::Vector { .. } => {
+ // Write vector access as a swizzle
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::ValuePointer { .. } => write!(self.out, "[{}]", index)?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ self.out,
+ ".{}",
+ &self.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => return Err(Error::Custom(format!("Cannot index {:?}", other))),
+ }
+ }
+ // Constants are delegated to `write_constant`
+ Expression::Constant(constant) => self.write_constant(constant)?,
+ // `Splat` needs to actually write down a vector, it's not always inferred in GLSL.
+ Expression::Splat { size: _, value } => {
+ let resolved = ctx.info[expr].ty.inner_with(&self.module.types);
+ self.write_value_type(resolved)?;
+ write!(self.out, "(")?;
+ self.write_expr(value, ctx)?;
+ write!(self.out, ")")?
+ }
+ // `Swizzle` adds a few letters behind the dot.
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(vector, ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ // `Compose` is pretty simple we just write `type(components)` where `components` is a
+ // comma separated list of expressions
+ Expression::Compose { ty, ref components } => {
+ self.write_type(ty)?;
+
+ let resolved = ctx.info[expr].ty.inner_with(&self.module.types);
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(base, size)?;
+ }
+
+ write!(self.out, "(")?;
+ self.write_slice(components, |this, _, arg| this.write_expr(*arg, ctx))?;
+ write!(self.out, ")")?
+ }
+ // Function arguments are written as the argument name
+ Expression::FunctionArgument(pos) => {
+ write!(self.out, "{}", &self.names[&ctx.argument_key(pos)])?
+ }
+ // Global variables need some special work for their name but
+ // `get_global_name` does the work for us
+ Expression::GlobalVariable(handle) => {
+ let global = &self.module.global_variables[handle];
+ self.write_global_name(handle, global)?
+ }
+ // A local is written as it's name
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&ctx.name_key(handle)])?
+ }
+ // glsl has no pointers so there's no load operation, just write the pointer expression
+ Expression::Load { pointer } => self.write_expr(pointer, ctx)?,
+ // `ImageSample` is a bit complicated compared to the rest of the IR.
+ //
+ // First there are three variations depending whether the sample level is explicitly set,
+ // if it's automatic or it it's bias:
+ // `texture(image, coordinate)` - Automatic sample level
+ // `texture(image, coordinate, bias)` - Bias sample level
+ // `textureLod(image, coordinate, level)` - Zero or Exact sample level
+ //
+ // Furthermore if `depth_ref` is some we need to append it to the coordinate vector
+ Expression::ImageSample {
+ image,
+ sampler: _, //TODO?
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ let dim = match *ctx.info[image].ty.inner_with(&self.module.types) {
+ TypeInner::Image { dim, .. } => dim,
+ _ => unreachable!(),
+ };
+
+ if dim == crate::ImageDimension::Cube
+ && array_index.is_some()
+ && depth_ref.is_some()
+ {
+ match level {
+ crate::SampleLevel::Zero
+ | crate::SampleLevel::Exact(_)
+ | crate::SampleLevel::Gradient { .. }
+ | crate::SampleLevel::Bias(_) => {
+ return Err(Error::Custom(String::from(
+ "gsamplerCubeArrayShadow isn't supported in textureGrad, \
+ textureLod or texture with bias",
+ )))
+ }
+ crate::SampleLevel::Auto => {}
+ }
+ }
+
+ // textureLod on sampler2DArrayShadow and samplerCubeShadow does not exist in GLSL.
+ // To emulate this, we will have to use textureGrad with a constant gradient of 0.
+ let workaround_lod_array_shadow_as_grad = (array_index.is_some()
+ || dim == crate::ImageDimension::Cube)
+ && depth_ref.is_some()
+ && gather.is_none()
+ && !self
+ .options
+ .writer_flags
+ .contains(WriterFlags::TEXTURE_SHADOW_LOD);
+
+ //Write the function to be used depending on the sample level
+ let fun_name = match level {
+ crate::SampleLevel::Zero if gather.is_some() => "textureGather",
+ crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture",
+ crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => {
+ if workaround_lod_array_shadow_as_grad {
+ "textureGrad"
+ } else {
+ "textureLod"
+ }
+ }
+ crate::SampleLevel::Gradient { .. } => "textureGrad",
+ };
+ let offset_name = match offset {
+ Some(_) => "Offset",
+ None => "",
+ };
+
+ write!(self.out, "{}{}(", fun_name, offset_name)?;
+
+ // Write the image that will be used
+ self.write_expr(image, ctx)?;
+ // The space here isn't required but it helps with readability
+ write!(self.out, ", ")?;
+
+ // We need to get the coordinates vector size to later build a vector that's `size + 1`
+ // if `depth_ref` is some, if it isn't a vector we panic as that's not a valid expression
+ let mut coord_dim = match *ctx.info[coordinate].ty.inner_with(&self.module.types) {
+ TypeInner::Vector { size, .. } => size as u8,
+ TypeInner::Scalar { .. } => 1,
+ _ => unreachable!(),
+ };
+
+ if array_index.is_some() {
+ coord_dim += 1;
+ }
+ let merge_depth_ref = depth_ref.is_some() && gather.is_none() && coord_dim < 4;
+ if merge_depth_ref {
+ coord_dim += 1;
+ }
+
+ let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es();
+ let is_vec = tex_1d_hack || coord_dim != 1;
+ // Compose a new texture coordinates vector
+ if is_vec {
+ write!(self.out, "vec{}(", coord_dim + tex_1d_hack as u8)?;
+ }
+ self.write_expr(coordinate, ctx)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0.0")?;
+ }
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ if merge_depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(depth_ref.unwrap(), ctx)?;
+ }
+ if is_vec {
+ write!(self.out, ")")?;
+ }
+
+ if let (Some(expr), false) = (depth_ref, merge_depth_ref) {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+
+ match level {
+ // Auto needs no more arguments
+ crate::SampleLevel::Auto => (),
+ // Zero needs level set to 0
+ crate::SampleLevel::Zero => {
+ if workaround_lod_array_shadow_as_grad {
+ let vec_dim = match dim {
+ crate::ImageDimension::Cube => 3,
+ _ => 2,
+ };
+ write!(self.out, ", vec{}(0.0), vec{}(0.0)", vec_dim, vec_dim)?;
+ } else if gather.is_none() {
+ write!(self.out, ", 0.0")?;
+ }
+ }
+ // Exact and bias require another argument
+ crate::SampleLevel::Exact(expr) => {
+ if workaround_lod_array_shadow_as_grad {
+ log::warn!("Unable to `textureLod` a shadow array, ignoring the LOD");
+ write!(self.out, ", vec2(0,0), vec2(0,0)")?;
+ } else {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ }
+ crate::SampleLevel::Bias(_) => {
+ // This needs to be done after the offset writing
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ // If we are using sampler2D to replace sampler1D, we also
+ // need to make sure to use vec2 gradients
+ if tex_1d_hack {
+ write!(self.out, ", vec2(")?;
+ self.write_expr(x, ctx)?;
+ write!(self.out, ", 0.0)")?;
+ write!(self.out, ", vec2(")?;
+ self.write_expr(y, ctx)?;
+ write!(self.out, ", 0.0)")?;
+ } else {
+ write!(self.out, ", ")?;
+ self.write_expr(x, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(y, ctx)?;
+ }
+ }
+ }
+
+ if let Some(constant) = offset {
+ write!(self.out, ", ")?;
+ if tex_1d_hack {
+ write!(self.out, "ivec2(")?;
+ }
+ self.write_constant(constant)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0)")?;
+ }
+ }
+
+ // Bias is always the last argument
+ if let crate::SampleLevel::Bias(expr) = level {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+
+ if let (Some(component), None) = (gather, depth_ref) {
+ write!(self.out, ", {}", component as usize)?;
+ }
+
+ // End the function
+ write!(self.out, ")")?
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => self.write_image_load(expr, ctx, image, coordinate, array_index, sample, level)?,
+ // Query translates into one of the:
+ // - textureSize/imageSize
+ // - textureQueryLevels
+ // - textureSamples/imageSamples
+ Expression::ImageQuery { image, query } => {
+ use crate::ImageClass;
+
+ // This will only panic if the module is invalid
+ let (dim, class) = match *ctx.info[image].ty.inner_with(&self.module.types) {
+ TypeInner::Image {
+ dim,
+ arrayed: _,
+ class,
+ } => (dim, class),
+ _ => unreachable!(),
+ };
+ let components = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 => 3,
+ crate::ImageDimension::Cube => 2,
+ };
+ match query {
+ crate::ImageQuery::Size { level } => {
+ match class {
+ ImageClass::Sampled { multi, .. } | ImageClass::Depth { multi } => {
+ write!(self.out, "textureSize(")?;
+ self.write_expr(image, ctx)?;
+ if let Some(expr) = level {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ } else if !multi {
+ // All textureSize calls requires an lod argument
+ // except for multisampled samplers
+ write!(self.out, ", 0")?;
+ }
+ }
+ ImageClass::Storage { .. } => {
+ write!(self.out, "imageSize(")?;
+ self.write_expr(image, ctx)?;
+ }
+ }
+ write!(self.out, ")")?;
+ if components != 1 || self.options.version.is_es() {
+ write!(self.out, ".{}", &"xyz"[..components])?;
+ }
+ }
+ crate::ImageQuery::NumLevels => {
+ write!(self.out, "textureQueryLevels(",)?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ")",)?;
+ }
+ crate::ImageQuery::NumLayers => {
+ let fun_name = match class {
+ ImageClass::Sampled { .. } | ImageClass::Depth { .. } => "textureSize",
+ ImageClass::Storage { .. } => "imageSize",
+ };
+ write!(self.out, "{}(", fun_name)?;
+ self.write_expr(image, ctx)?;
+ // All textureSize calls requires an lod argument
+ // except for multisampled samplers
+ if class.is_multisampled() {
+ write!(self.out, ", 0")?;
+ }
+ write!(self.out, ")")?;
+ if components != 1 || self.options.version.is_es() {
+ write!(self.out, ".{}", back::COMPONENTS[components])?;
+ }
+ }
+ crate::ImageQuery::NumSamples => {
+ let fun_name = match class {
+ ImageClass::Sampled { .. } | ImageClass::Depth { .. } => {
+ "textureSamples"
+ }
+ ImageClass::Storage { .. } => "imageSamples",
+ };
+ write!(self.out, "{}(", fun_name)?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ")",)?;
+ }
+ }
+ }
+ // `Unary` is pretty straightforward
+ // "-" - for `Negate`
+ // "~" - for `Not` if it's an integer
+ // "!" - for `Not` if it's a boolean
+ //
+ // We also wrap the everything in parentheses to avoid precedence issues
+ Expression::Unary { op, expr } => {
+ use crate::{ScalarKind as Sk, UnaryOperator as Uo};
+
+ let ty = ctx.info[expr].ty.inner_with(&self.module.types);
+
+ match *ty {
+ TypeInner::Vector { kind: Sk::Bool, .. } => {
+ write!(self.out, "not(")?;
+ }
+ _ => {
+ let operator = match op {
+ Uo::Negate => "-",
+ Uo::Not => match ty.scalar_kind() {
+ Some(Sk::Sint) | Some(Sk::Uint) => "~",
+ Some(Sk::Bool) => "!",
+ ref other => {
+ return Err(Error::Custom(format!(
+ "Cannot apply not to type {:?}",
+ other
+ )))
+ }
+ },
+ };
+
+ write!(self.out, "({}", operator)?;
+ }
+ }
+
+ self.write_expr(expr, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ // `Binary` we just write `left op right`, except when dealing with
+ // comparison operations on vectors as they are implemented with
+ // builtin functions.
+ // Once again we wrap everything in parentheses to avoid precedence issues
+ Expression::Binary {
+ mut op,
+ left,
+ right,
+ } => {
+ // Holds `Some(function_name)` if the binary operation is
+ // implemented as a function call
+ use crate::{BinaryOperator as Bo, ScalarKind as Sk, TypeInner as Ti};
+
+ let left_inner = ctx.info[left].ty.inner_with(&self.module.types);
+ let right_inner = ctx.info[right].ty.inner_with(&self.module.types);
+
+ let function = match (left_inner, right_inner) {
+ (&Ti::Vector { kind, .. }, &Ti::Vector { .. }) => match op {
+ Bo::Less
+ | Bo::LessEqual
+ | Bo::Greater
+ | Bo::GreaterEqual
+ | Bo::Equal
+ | Bo::NotEqual => BinaryOperation::VectorCompare,
+ Bo::Modulo if kind == Sk::Float => BinaryOperation::Modulo,
+ Bo::And if kind == Sk::Bool => {
+ op = crate::BinaryOperator::LogicalAnd;
+ BinaryOperation::VectorComponentWise
+ }
+ Bo::InclusiveOr if kind == Sk::Bool => {
+ op = crate::BinaryOperator::LogicalOr;
+ BinaryOperation::VectorComponentWise
+ }
+ _ => BinaryOperation::Other,
+ },
+ _ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) {
+ (Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op {
+ Bo::Modulo => BinaryOperation::Modulo,
+ _ => BinaryOperation::Other,
+ },
+ (Some(Sk::Bool), Some(Sk::Bool)) => match op {
+ Bo::InclusiveOr => {
+ op = crate::BinaryOperator::LogicalOr;
+ BinaryOperation::Other
+ }
+ Bo::And => {
+ op = crate::BinaryOperator::LogicalAnd;
+ BinaryOperation::Other
+ }
+ _ => BinaryOperation::Other,
+ },
+ _ => BinaryOperation::Other,
+ },
+ };
+
+ match function {
+ BinaryOperation::VectorCompare => {
+ let op_str = match op {
+ Bo::Less => "lessThan(",
+ Bo::LessEqual => "lessThanEqual(",
+ Bo::Greater => "greaterThan(",
+ Bo::GreaterEqual => "greaterThanEqual(",
+ Bo::Equal => "equal(",
+ Bo::NotEqual => "notEqual(",
+ _ => unreachable!(),
+ };
+ write!(self.out, "{}", op_str)?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?;
+ }
+ BinaryOperation::VectorComponentWise => {
+ self.write_value_type(left_inner)?;
+ write!(self.out, "(")?;
+
+ let size = match *left_inner {
+ Ti::Vector { size, .. } => size,
+ _ => unreachable!(),
+ };
+
+ for i in 0..size as usize {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+
+ self.write_expr(left, ctx)?;
+ write!(self.out, ".{}", back::COMPONENTS[i])?;
+
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+
+ self.write_expr(right, ctx)?;
+ write!(self.out, ".{}", back::COMPONENTS[i])?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+ BinaryOperation::Modulo => {
+ write!(self.out, "(")?;
+
+ // write `e1 - e2 * trunc(e1 / e2)`
+ self.write_expr(left, ctx)?;
+ write!(self.out, " - ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, " * ")?;
+ write!(self.out, "trunc(")?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, " / ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?;
+
+ write!(self.out, ")")?;
+ }
+ BinaryOperation::Other => {
+ write!(self.out, "(")?;
+
+ self.write_expr(left, ctx)?;
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+ self.write_expr(right, ctx)?;
+
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ // `Select` is written as `condition ? accept : reject`
+ // We wrap everything in parentheses to avoid precedence issues
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let cond_ty = ctx.info[condition].ty.inner_with(&self.module.types);
+ let vec_select = if let TypeInner::Vector { .. } = *cond_ty {
+ true
+ } else {
+ false
+ };
+
+ // TODO: Boolean mix on desktop required GL_EXT_shader_integer_mix
+ if vec_select {
+ // Glsl defines that for mix when the condition is a boolean the first element
+ // is picked if condition is false and the second if condition is true
+ write!(self.out, "mix(")?;
+ self.write_expr(reject, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(accept, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(condition, ctx)?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(condition, ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(accept, ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(reject, ctx)?;
+ }
+
+ write!(self.out, ")")?
+ }
+ // `Derivative` is a function call to a glsl provided function
+ Expression::Derivative { axis, expr } => {
+ use crate::DerivativeAxis as Da;
+
+ write!(
+ self.out,
+ "{}(",
+ match axis {
+ Da::X => "dFdx",
+ Da::Y => "dFdy",
+ Da::Width => "fwidth",
+ }
+ )?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ // `Relational` is a normal function call to some glsl provided functions
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_name = match fun {
+ // There's no specific function for this but we can invert the result of `isinf`
+ Rf::IsFinite => "!isinf",
+ Rf::IsInf => "isinf",
+ Rf::IsNan => "isnan",
+ // There's also no function for this but we can invert `isnan`
+ Rf::IsNormal => "!isnan",
+ Rf::All => "all",
+ Rf::Any => "any",
+ };
+ write!(self.out, "{}(", fun_name)?;
+
+ self.write_expr(argument, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let fun_name = match fun {
+ // comparison
+ Mf::Abs => "abs",
+ Mf::Min => "min",
+ Mf::Max => "max",
+ Mf::Clamp => "clamp",
+ Mf::Saturate => {
+ write!(self.out, "clamp(")?;
+
+ self.write_expr(arg, ctx)?;
+
+ match *ctx.info[arg].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Vector { size, .. } => write!(
+ self.out,
+ ", vec{}(0.0), vec{0}(1.0)",
+ back::vector_size_str(size)
+ )?,
+ _ => write!(self.out, ", 0.0, 1.0")?,
+ }
+
+ write!(self.out, ")")?;
+
+ return Ok(());
+ }
+ // trigonometry
+ Mf::Cos => "cos",
+ Mf::Cosh => "cosh",
+ Mf::Sin => "sin",
+ Mf::Sinh => "sinh",
+ Mf::Tan => "tan",
+ Mf::Tanh => "tanh",
+ Mf::Acos => "acos",
+ Mf::Asin => "asin",
+ Mf::Atan => "atan",
+ Mf::Asinh => "asinh",
+ Mf::Acosh => "acosh",
+ Mf::Atanh => "atanh",
+ Mf::Radians => "radians",
+ Mf::Degrees => "degrees",
+ // glsl doesn't have atan2 function
+ // use two-argument variation of the atan function
+ Mf::Atan2 => "atan",
+ // decomposition
+ Mf::Ceil => "ceil",
+ Mf::Floor => "floor",
+ Mf::Round => "roundEven",
+ Mf::Fract => "fract",
+ Mf::Trunc => "trunc",
+ Mf::Modf => "modf",
+ Mf::Frexp => "frexp",
+ Mf::Ldexp => "ldexp",
+ // exponent
+ Mf::Exp => "exp",
+ Mf::Exp2 => "exp2",
+ Mf::Log => "log",
+ Mf::Log2 => "log2",
+ Mf::Pow => "pow",
+ // geometry
+ Mf::Dot => match *ctx.info[arg].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Float,
+ ..
+ } => "dot",
+ crate::TypeInner::Vector { size, .. } => {
+ return self.write_dot_product(arg, arg1.unwrap(), size as usize, ctx)
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => "outerProduct",
+ Mf::Cross => "cross",
+ Mf::Distance => "distance",
+ Mf::Length => "length",
+ Mf::Normalize => "normalize",
+ Mf::FaceForward => "faceforward",
+ Mf::Reflect => "reflect",
+ Mf::Refract => "refract",
+ // computational
+ Mf::Sign => "sign",
+ Mf::Fma => {
+ if self.options.version.supports_fma_function() {
+ // Use the fma function when available
+ "fma"
+ } else {
+ // No fma support. Transform the function call into an arithmetic expression
+ write!(self.out, "(")?;
+
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " * ")?;
+
+ let arg1 =
+ arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
+ self.write_expr(arg1, ctx)?;
+ write!(self.out, " + ")?;
+
+ let arg2 =
+ arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
+ self.write_expr(arg2, ctx)?;
+ write!(self.out, ")")?;
+
+ return Ok(());
+ }
+ }
+ Mf::Mix => "mix",
+ Mf::Step => "step",
+ Mf::SmoothStep => "smoothstep",
+ Mf::Sqrt => "sqrt",
+ Mf::InverseSqrt => "inversesqrt",
+ Mf::Inverse => "inverse",
+ Mf::Transpose => "transpose",
+ Mf::Determinant => "determinant",
+ // bits
+ Mf::CountOneBits => "bitCount",
+ Mf::ReverseBits => "bitfieldReverse",
+ Mf::ExtractBits => "bitfieldExtract",
+ Mf::InsertBits => "bitfieldInsert",
+ Mf::FindLsb => "findLSB",
+ Mf::FindMsb => "findMSB",
+ // data packing
+ Mf::Pack4x8snorm => "packSnorm4x8",
+ Mf::Pack4x8unorm => "packUnorm4x8",
+ Mf::Pack2x16snorm => "packSnorm2x16",
+ Mf::Pack2x16unorm => "packUnorm2x16",
+ Mf::Pack2x16float => "packHalf2x16",
+ // data unpacking
+ Mf::Unpack4x8snorm => "unpackSnorm4x8",
+ Mf::Unpack4x8unorm => "unpackUnorm4x8",
+ Mf::Unpack2x16snorm => "unpackSnorm2x16",
+ Mf::Unpack2x16unorm => "unpackUnorm2x16",
+ Mf::Unpack2x16float => "unpackHalf2x16",
+ };
+
+ let extract_bits = fun == Mf::ExtractBits;
+ let insert_bits = fun == Mf::InsertBits;
+
+ // Some GLSL functions always return signed integers (like findMSB),
+ // so they need to be cast to uint if the argument is also an uint.
+ let ret_might_need_int_to_uint =
+ matches!(fun, Mf::FindLsb | Mf::FindMsb | Mf::CountOneBits | Mf::Abs);
+
+ // Some GLSL functions only accept signed integers (like abs),
+ // so they need their argument cast from uint to int.
+ let arg_might_need_uint_to_int = matches!(fun, Mf::Abs);
+
+ // Check if the argument is an unsigned integer and return the vector size
+ // in case it's a vector
+ let maybe_uint_size = match *ctx.info[arg].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ } => Some(None),
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Uint,
+ size,
+ ..
+ } => Some(Some(size)),
+ _ => None,
+ };
+
+ // Cast to uint if the function needs it
+ if ret_might_need_int_to_uint {
+ if let Some(maybe_size) = maybe_uint_size {
+ match maybe_size {
+ Some(size) => write!(self.out, "uvec{}(", size as u8)?,
+ None => write!(self.out, "uint(")?,
+ }
+ }
+ }
+
+ write!(self.out, "{}(", fun_name)?;
+
+ // Cast to int if the function needs it
+ if arg_might_need_uint_to_int {
+ if let Some(maybe_size) = maybe_uint_size {
+ match maybe_size {
+ Some(size) => write!(self.out, "ivec{}(", size as u8)?,
+ None => write!(self.out, "int(")?,
+ }
+ }
+ }
+
+ self.write_expr(arg, ctx)?;
+
+ // Close the cast from uint to int
+ if arg_might_need_uint_to_int && maybe_uint_size.is_some() {
+ write!(self.out, ")")?
+ }
+
+ if let Some(arg) = arg1 {
+ write!(self.out, ", ")?;
+ if extract_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ if let Some(arg) = arg2 {
+ write!(self.out, ", ")?;
+ if extract_bits || insert_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ if let Some(arg) = arg3 {
+ write!(self.out, ", ")?;
+ if insert_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ // Close the cast from int to uint
+ if ret_might_need_int_to_uint && maybe_uint_size.is_some() {
+ write!(self.out, ")")?
+ }
+ }
+ // `As` is always a call.
+ // If `convert` is true the function name is the type
+ // Else the function name is one of the glsl provided bitcast functions
+ Expression::As {
+ expr,
+ kind: target_kind,
+ convert,
+ } => {
+ let inner = ctx.info[expr].ty.inner_with(&self.module.types);
+ match convert {
+ Some(width) => {
+ // this is similar to `write_type`, but with the target kind
+ let scalar = glsl_scalar(target_kind, width)?;
+ match *inner {
+ TypeInner::Matrix { columns, rows, .. } => write!(
+ self.out,
+ "{}mat{}x{}",
+ scalar.prefix, columns as u8, rows as u8
+ )?,
+ TypeInner::Vector { size, .. } => {
+ write!(self.out, "{}vec{}", scalar.prefix, size as u8)?
+ }
+ _ => write!(self.out, "{}", scalar.full)?,
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ None => {
+ use crate::ScalarKind as Sk;
+
+ let target_vector_type = match *inner {
+ TypeInner::Vector { size, width, .. } => Some(TypeInner::Vector {
+ size,
+ width,
+ kind: target_kind,
+ }),
+ _ => None,
+ };
+
+ let source_kind = inner.scalar_kind().unwrap();
+
+ match (source_kind, target_kind, target_vector_type) {
+ // No conversion needed
+ (Sk::Sint, Sk::Sint, _)
+ | (Sk::Uint, Sk::Uint, _)
+ | (Sk::Float, Sk::Float, _)
+ | (Sk::Bool, Sk::Bool, _) => {
+ self.write_expr(expr, ctx)?;
+ return Ok(());
+ }
+
+ // Cast to/from floats
+ (Sk::Float, Sk::Sint, _) => write!(self.out, "floatBitsToInt")?,
+ (Sk::Float, Sk::Uint, _) => write!(self.out, "floatBitsToUint")?,
+ (Sk::Sint, Sk::Float, _) => write!(self.out, "intBitsToFloat")?,
+ (Sk::Uint, Sk::Float, _) => write!(self.out, "uintBitsToFloat")?,
+
+ // Cast between vector types
+ (_, _, Some(vector)) => {
+ self.write_value_type(&vector)?;
+ }
+
+ // There is no way to bitcast between Uint/Sint in glsl. Use constructor conversion
+ (Sk::Uint | Sk::Bool, Sk::Sint, None) => write!(self.out, "int")?,
+ (Sk::Sint | Sk::Bool, Sk::Uint, None) => write!(self.out, "uint")?,
+ (Sk::Bool, Sk::Float, None) => write!(self.out, "float")?,
+ (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => {
+ write!(self.out, "bool")?
+ }
+ };
+
+ write!(self.out, "(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ // These expressions never show up in `Emit`.
+ Expression::CallResult(_) | Expression::AtomicResult { .. } => unreachable!(),
+ // `ArrayLength` is written as `expr.length()` and we convert it to a uint
+ Expression::ArrayLength(expr) => {
+ write!(self.out, "uint(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ".length())")?
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper function to write the local holding the clamped lod
+ fn write_clamped_lod(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ expr: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ level_expr: Handle<crate::Expression>,
+ ) -> Result<(), Error> {
+ // Define our local and start a call to `clamp`
+ write!(
+ self.out,
+ "int {}{}{} = clamp(",
+ back::BAKE_PREFIX,
+ expr.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ // Write the lod that will be clamped
+ self.write_expr(level_expr, ctx)?;
+ // Set the min value to 0 and start a call to `textureQueryLevels` to get
+ // the maximum value
+ write!(self.out, ", 0, textureQueryLevels(")?;
+ // Write the target image as an argument to `textureQueryLevels`
+ self.write_expr(image, ctx)?;
+ // Close the call to `textureQueryLevels` subtract 1 from it since
+ // the lod argument is 0 based, close the `clamp` call and end the
+ // local declaration statement.
+ writeln!(self.out, ") - 1);")?;
+
+ Ok(())
+ }
+
+ // Helper method used to retrieve how many elements a coordinate vector
+ // for the images operations need.
+ fn get_coordinate_vector_size(&self, dim: crate::ImageDimension, arrayed: bool) -> u8 {
+ // openGL es doesn't have 1D images so we need workaround it
+ let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es();
+ // Get how many components the coordinate vector needs for the dimensions only
+ let tex_coord_size = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 => 3,
+ crate::ImageDimension::Cube => 2,
+ };
+ // Calculate the true size of the coordinate vector by adding 1 for arrayed images
+ // and another 1 if we need to workaround 1D images by making them 2D
+ tex_coord_size + tex_1d_hack as u8 + arrayed as u8
+ }
+
+ /// Helper method to write the coordinate vector for image operations
+ fn write_texture_coord(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ vector_size: u8,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ // Emulate 1D images as 2D for profiles that don't support it (glsl es)
+ tex_1d_hack: bool,
+ ) -> Result<(), Error> {
+ match array_index {
+ // If the image needs an array indice we need to add it to the end of our
+ // coordinate vector, to do so we will use the `ivec(ivec, scalar)`
+ // constructor notation (NOTE: the inner `ivec` can also be a scalar, this
+ // is important for 1D arrayed images).
+ Some(layer_expr) => {
+ write!(self.out, "ivec{}(", vector_size)?;
+ self.write_expr(coordinate, ctx)?;
+ write!(self.out, ", ")?;
+ // If we are replacing sampler1D with sampler2D we also need
+ // to add another zero to the coordinates vector for the y component
+ if tex_1d_hack {
+ write!(self.out, "0, ")?;
+ }
+ self.write_expr(layer_expr, ctx)?;
+ write!(self.out, ")")?;
+ }
+ // Otherwise write just the expression (and the 1D hack if needed)
+ None => {
+ if tex_1d_hack {
+ write!(self.out, "ivec2(")?;
+ }
+ self.write_expr(coordinate, ctx)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0)")?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write the `ImageStore` statement
+ fn write_image_store(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ value: Handle<crate::Expression>,
+ ) -> Result<(), Error> {
+ use crate::ImageDimension as IDim;
+
+ // NOTE: openGL requires that `imageStore`s have no effets when the texel is invalid
+ // so we don't need to generate bounds checks (OpenGL 4.2 Core §3.9.20)
+
+ // This will only panic if the module is invalid
+ let dim = match *ctx.info[image].ty.inner_with(&self.module.types) {
+ TypeInner::Image { dim, .. } => dim,
+ _ => unreachable!(),
+ };
+
+ // Begin our call to `imageStore`
+ write!(self.out, "imageStore(")?;
+ self.write_expr(image, ctx)?;
+ // Separate the image argument from the coordinates
+ write!(self.out, ", ")?;
+
+ // openGL es doesn't have 1D images so we need workaround it
+ let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
+ // Write the coordinate vector
+ self.write_texture_coord(
+ ctx,
+ // Get the size of the coordinate vector
+ self.get_coordinate_vector_size(dim, array_index.is_some()),
+ coordinate,
+ array_index,
+ tex_1d_hack,
+ )?;
+
+ // Separate the coordinate from the value to write and write the expression
+ // of the value to write.
+ write!(self.out, ", ")?;
+ self.write_expr(value, ctx)?;
+ // End the call to `imageStore` and the statement.
+ writeln!(self.out, ");")?;
+
+ Ok(())
+ }
+
+ /// Helper method for writing an `ImageLoad` expression.
+ #[allow(clippy::too_many_arguments)]
+ fn write_image_load(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ level: Option<Handle<crate::Expression>>,
+ ) -> Result<(), Error> {
+ use crate::ImageDimension as IDim;
+
+ // `ImageLoad` is a bit complicated.
+ // There are two functions one for sampled
+ // images another for storage images, the former uses `texelFetch` and the
+ // latter uses `imageLoad`.
+ //
+ // Furthermore we have `level` which is always `Some` for sampled images
+ // and `None` for storage images, so we end up with two functions:
+ // - `texelFetch(image, coordinate, level)` for sampled images
+ // - `imageLoad(image, coordinate)` for storage images
+ //
+ // Finally we also have to consider bounds checking, for storage images
+ // this is easy since openGL requires that invalid texels always return
+ // 0, for sampled images we need to either verify that all arguments are
+ // in bounds (`ReadZeroSkipWrite`) or make them a valid texel (`Restrict`).
+
+ // This will only panic if the module is invalid
+ let (dim, class) = match *ctx.info[image].ty.inner_with(&self.module.types) {
+ TypeInner::Image {
+ dim,
+ arrayed: _,
+ class,
+ } => (dim, class),
+ _ => unreachable!(),
+ };
+
+ // Get the name of the function to be used for the load operation
+ // and the policy to be used with it.
+ let (fun_name, policy) = match class {
+ // Sampled images inherit the policy from the user passed policies
+ crate::ImageClass::Sampled { .. } => ("texelFetch", self.policies.image),
+ crate::ImageClass::Storage { .. } => {
+ // OpenGL 4.2 Core §3.9.20 defines that out of bounds texels in `imageLoad`s
+ // always return zero values so we don't need to generate bounds checks
+ ("imageLoad", proc::BoundsCheckPolicy::Unchecked)
+ }
+ // TODO: Is there even a function for this?
+ crate::ImageClass::Depth { multi: _ } => {
+ return Err(Error::Custom(
+ "WGSL `textureLoad` from depth textures is not supported in GLSL".to_string(),
+ ))
+ }
+ };
+
+ // openGL es doesn't have 1D images so we need workaround it
+ let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
+ // Get the size of the coordinate vector
+ let vector_size = self.get_coordinate_vector_size(dim, array_index.is_some());
+
+ if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy {
+ // To write the bounds checks for `ReadZeroSkipWrite` we will use a
+ // ternary operator since we are in the middle of an expression and
+ // need to return a value.
+ //
+ // NOTE: glsl does short circuit when evaluating logical
+ // expressions so we can be sure that after we test a
+ // condition it will be true for the next ones
+
+ // Write parantheses around the ternary operator to prevent problems with
+ // expressions emitted before or after it having more precedence
+ write!(self.out, "(",)?;
+
+ // The lod check needs to precede the size check since we need
+ // to use the lod to get the size of the image at that level.
+ if let Some(level_expr) = level {
+ self.write_expr(level_expr, ctx)?;
+ write!(self.out, " < textureQueryLevels(",)?;
+ self.write_expr(image, ctx)?;
+ // Chain the next check
+ write!(self.out, ") && ")?;
+ }
+
+ // Check that the sample arguments doesn't exceed the number of samples
+ if let Some(sample_expr) = sample {
+ self.write_expr(sample_expr, ctx)?;
+ write!(self.out, " < textureSamples(",)?;
+ self.write_expr(image, ctx)?;
+ // Chain the next check
+ write!(self.out, ") && ")?;
+ }
+
+ // We now need to write the size checks for the coordinates and array index
+ // first we write the comparation function in case the image is 1D non arrayed
+ // (and no 1D to 2D hack was needed) we are comparing scalars so the less than
+ // operator will suffice, but otherwise we'll be comparing two vectors so we'll
+ // need to use the `lessThan` function but it returns a vector of booleans (one
+ // for each comparison) so we need to fold it all in one scalar boolean, since
+ // we want all comparisons to pass we use the `all` function which will only
+ // return `true` if all the elements of the boolean vector are also `true`.
+ //
+ // So we'll end with one of the following forms
+ // - `coord < textureSize(image, lod)` for 1D images
+ // - `all(lessThan(coord, textureSize(image, lod)))` for normal images
+ // - `all(lessThan(ivec(coord, array_index), textureSize(image, lod)))`
+ // for arrayed images
+ // - `all(lessThan(coord, textureSize(image)))` for multi sampled images
+
+ if vector_size != 1 {
+ write!(self.out, "all(lessThan(")?;
+ }
+
+ // Write the coordinate vector
+ self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;
+
+ if vector_size != 1 {
+ // If we used the `lessThan` function we need to separate the
+ // coordinates from the image size.
+ write!(self.out, ", ")?;
+ } else {
+ // If we didn't use it (ie. 1D images) we perform the comparsion
+ // using the less than operator.
+ write!(self.out, " < ")?;
+ }
+
+ // Call `textureSize` to get our image size
+ write!(self.out, "textureSize(")?;
+ self.write_expr(image, ctx)?;
+ // `textureSize` uses the lod as a second argument for mipmapped images
+ if let Some(level_expr) = level {
+ // Separate the image from the lod
+ write!(self.out, ", ")?;
+ self.write_expr(level_expr, ctx)?;
+ }
+ // Close the `textureSize` call
+ write!(self.out, ")")?;
+
+ if vector_size != 1 {
+ // Close the `all` and `lessThan` calls
+ write!(self.out, "))")?;
+ }
+
+ // Finally end the condition part of the ternary operator
+ write!(self.out, " ? ")?;
+ }
+
+ // Begin the call to the function used to load the texel
+ write!(self.out, "{}(", fun_name)?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ", ")?;
+
+ // If we are using `Restrict` bounds checking we need to pass valid texel
+ // coordinates, to do so we use the `clamp` function to get a value between
+ // 0 and the image size - 1 (indexing begins at 0)
+ if let proc::BoundsCheckPolicy::Restrict = policy {
+ write!(self.out, "clamp(")?;
+ }
+
+ // Write the coordinate vector
+ self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;
+
+ // If we are using `Restrict` bounds checking we need to write the rest of the
+ // clamp we initiated before writing the coordinates.
+ if let proc::BoundsCheckPolicy::Restrict = policy {
+ // Write the min value 0
+ if vector_size == 1 {
+ write!(self.out, ", 0")?;
+ } else {
+ write!(self.out, ", ivec{}(0)", vector_size)?;
+ }
+ // Start the `textureSize` call to use as the max value.
+ write!(self.out, ", textureSize(")?;
+ self.write_expr(image, ctx)?;
+ // If the image is mipmapped we need to add the lod argument to the
+ // `textureSize` call, but this needs to be the clamped lod, this should
+ // have been generated earlier and put in a local.
+ if class.is_mipmapped() {
+ write!(
+ self.out,
+ ", {}{}{}",
+ back::BAKE_PREFIX,
+ handle.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ }
+ // Close the `textureSize` call
+ write!(self.out, ")")?;
+
+ // Subtract 1 from the `textureSize` call since the coordinates are zero based.
+ if vector_size == 1 {
+ write!(self.out, " - 1")?;
+ } else {
+ write!(self.out, " - ivec{}(1)", vector_size)?;
+ }
+
+ // Close the `clamp` call
+ write!(self.out, ")")?;
+
+ // Add the clamped lod (if present) as the second argument to the
+ // image load function.
+ if level.is_some() {
+ write!(
+ self.out,
+ ", {}{}{}",
+ back::BAKE_PREFIX,
+ handle.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ }
+
+ // If a sample argument is needed we need to clamp it between 0 and
+ // the number of samples the image has.
+ if let Some(sample_expr) = sample {
+ write!(self.out, ", clamp(")?;
+ self.write_expr(sample_expr, ctx)?;
+ // Set the min value to 0 and start the call to `textureSamples`
+ write!(self.out, ", 0, textureSamples(")?;
+ self.write_expr(image, ctx)?;
+ // Close the `textureSamples` call, subtract 1 from it since the sample
+ // argument is zero based, and close the `clamp` call
+ writeln!(self.out, ") - 1)")?;
+ }
+ } else if let Some(sample_or_level) = sample.or(level) {
+ // If no bounds checking is need just add the sample or level argument
+ // after the coordinates
+ write!(self.out, ", ")?;
+ self.write_expr(sample_or_level, ctx)?;
+ }
+
+ // Close the image load function.
+ write!(self.out, ")")?;
+
+ // If we were using the `ReadZeroSkipWrite` policy we need to end the first branch
+ // (which is taken if the condition is `true`) with a colon (`:`) and write the
+ // second branch which is just a 0 value.
+ if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy {
+ // Get the kind of the output value.
+ let kind = match class {
+ // Only sampled images can reach here since storage images
+ // don't need bounds checks and depth images aren't implmented
+ crate::ImageClass::Sampled { kind, .. } => kind,
+ _ => unreachable!(),
+ };
+
+ // End the first branch
+ write!(self.out, " : ")?;
+ // Write the 0 value
+ write!(self.out, "{}vec4(", glsl_scalar(kind, 4)?.prefix,)?;
+ self.write_zero_init_scalar(kind)?;
+ // Close the zero value constructor
+ write!(self.out, ")")?;
+ // Close the parantheses surrounding our ternary
+ write!(self.out, ")")?;
+ }
+
+ Ok(())
+ }
+
+ fn write_named_expr(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ name: String,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ match ctx.info[handle].ty {
+ proc::TypeResolution::Handle(ty_handle) => match self.module.types[ty_handle].inner {
+ TypeInner::Struct { .. } => {
+ let ty_name = &self.names[&NameKey::Type(ty_handle)];
+ write!(self.out, "{}", ty_name)?;
+ }
+ _ => {
+ self.write_type(ty_handle)?;
+ }
+ },
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(inner)?;
+ }
+ }
+
+ let base_ty_res = &ctx.info[handle].ty;
+ let resolved = base_ty_res.inner_with(&self.module.types);
+
+ write!(self.out, " {}", name)?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_expr(handle, ctx)?;
+ writeln!(self.out, ";")?;
+ self.named_expressions.insert(handle, name);
+
+ Ok(())
+ }
+
+ /// Helper function that write string with default zero initialization for supported types
+ fn write_zero_init_value(&mut self, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &self.module.types[ty].inner;
+ match *inner {
+ TypeInner::Scalar { kind, .. } => {
+ self.write_zero_init_scalar(kind)?;
+ }
+ TypeInner::Vector { kind, .. } => {
+ self.write_value_type(inner)?;
+ write!(self.out, "(")?;
+ self.write_zero_init_scalar(kind)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Matrix { .. } => {
+ self.write_value_type(inner)?;
+ write!(self.out, "(")?;
+ self.write_zero_init_scalar(crate::ScalarKind::Float)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Array { base, size, .. } => {
+ let count = match size
+ .to_indexable_length(self.module)
+ .expect("Bad array size")
+ {
+ proc::IndexableLength::Known(count) => count,
+ proc::IndexableLength::Dynamic => return Ok(()),
+ };
+ self.write_type(base)?;
+ self.write_array_size(base, size)?;
+ write!(self.out, "(")?;
+ for _ in 1..count {
+ self.write_zero_init_value(base)?;
+ write!(self.out, ", ")?;
+ }
+ // write last parameter without comma and space
+ self.write_zero_init_value(base)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{}(", name)?;
+ for (i, member) in members.iter().enumerate() {
+ self.write_zero_init_value(member.ty)?;
+ if i != members.len().saturating_sub(1) {
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ")")?;
+ }
+ _ => {} // TODO:
+ }
+
+ Ok(())
+ }
+
+ /// Helper function that write string with zero initialization for scalar
+ fn write_zero_init_scalar(&mut self, kind: crate::ScalarKind) -> BackendResult {
+ match kind {
+ crate::ScalarKind::Bool => write!(self.out, "false")?,
+ crate::ScalarKind::Uint => write!(self.out, "0u")?,
+ crate::ScalarKind::Float => write!(self.out, "0.0")?,
+ crate::ScalarKind::Sint => write!(self.out, "0")?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper function that return the glsl storage access string of [`StorageAccess`](crate::StorageAccess)
+ ///
+ /// glsl allows adding both `readonly` and `writeonly` but this means that
+ /// they can only be used to query information about the resource which isn't what
+ /// we want here so when storage access is both `LOAD` and `STORE` add no modifiers
+ fn write_storage_access(&mut self, storage_access: crate::StorageAccess) -> BackendResult {
+ if !storage_access.contains(crate::StorageAccess::STORE) {
+ write!(self.out, "readonly ")?;
+ }
+ if !storage_access.contains(crate::StorageAccess::LOAD) {
+ write!(self.out, "writeonly ")?;
+ }
+ Ok(())
+ }
+
+ /// Helper method used to produce the reflection info that's returned to the user
+ fn collect_reflection_info(&self) -> Result<ReflectionInfo, Error> {
+ use std::collections::hash_map::Entry;
+ let info = self.info.get_entry_point(self.entry_point_idx as usize);
+ let mut texture_mapping = crate::FastHashMap::default();
+ let mut uniforms = crate::FastHashMap::default();
+
+ for sampling in info.sampling_set.iter() {
+ let tex_name = self.reflection_names_globals[&sampling.image].clone();
+
+ match texture_mapping.entry(tex_name) {
+ Entry::Vacant(v) => {
+ v.insert(TextureMapping {
+ texture: sampling.image,
+ sampler: Some(sampling.sampler),
+ });
+ }
+ Entry::Occupied(e) => {
+ if e.get().sampler != Some(sampling.sampler) {
+ log::error!("Conflicting samplers for {}", e.key());
+ return Err(Error::ImageMultipleSamplers);
+ }
+ }
+ }
+ }
+
+ for (handle, var) in self.module.global_variables.iter() {
+ if info[handle].is_empty() {
+ continue;
+ }
+ match self.module.types[var.ty].inner {
+ crate::TypeInner::Struct { .. } => match var.space {
+ crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => {
+ let name = self.reflection_names_globals[&handle].clone();
+ uniforms.insert(handle, name);
+ }
+ _ => (),
+ },
+ crate::TypeInner::Image { .. } => {
+ let tex_name = self.reflection_names_globals[&handle].clone();
+ match texture_mapping.entry(tex_name) {
+ Entry::Vacant(v) => {
+ v.insert(TextureMapping {
+ texture: handle,
+ sampler: None,
+ });
+ }
+ Entry::Occupied(_) => {
+ // already used with a sampler, do nothing
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+
+ Ok(ReflectionInfo {
+ texture_mapping,
+ uniforms,
+ })
+ }
+}
+
+/// Structure returned by [`glsl_scalar`](glsl_scalar)
+///
+/// It contains both a prefix used in other types and the full type name
+struct ScalarString<'a> {
+ /// The prefix used to compose other types
+ prefix: &'a str,
+ /// The name of the scalar type
+ full: &'a str,
+}
+
+/// Helper function that returns scalar related strings
+///
+/// Check [`ScalarString`](ScalarString) for the information provided
+///
+/// # Errors
+/// If a [`Float`](crate::ScalarKind::Float) with an width that isn't 4 or 8
+const fn glsl_scalar(
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+) -> Result<ScalarString<'static>, Error> {
+ use crate::ScalarKind as Sk;
+
+ Ok(match kind {
+ Sk::Sint => ScalarString {
+ prefix: "i",
+ full: "int",
+ },
+ Sk::Uint => ScalarString {
+ prefix: "u",
+ full: "uint",
+ },
+ Sk::Float => match width {
+ 4 => ScalarString {
+ prefix: "",
+ full: "float",
+ },
+ 8 => ScalarString {
+ prefix: "d",
+ full: "double",
+ },
+ _ => return Err(Error::UnsupportedScalar(kind, width)),
+ },
+ Sk::Bool => ScalarString {
+ prefix: "b",
+ full: "bool",
+ },
+ })
+}
+
+/// Helper function that returns the glsl variable name for a builtin
+const fn glsl_built_in(
+ built_in: crate::BuiltIn,
+ output: bool,
+ targetting_webgl: bool,
+) -> &'static str {
+ use crate::BuiltIn as Bi;
+
+ match built_in {
+ Bi::Position { .. } => {
+ if output {
+ "gl_Position"
+ } else {
+ "gl_FragCoord"
+ }
+ }
+ Bi::ViewIndex if targetting_webgl => "int(gl_ViewID_OVR)",
+ Bi::ViewIndex => "gl_ViewIndex",
+ // vertex
+ Bi::BaseInstance => "uint(gl_BaseInstance)",
+ Bi::BaseVertex => "uint(gl_BaseVertex)",
+ Bi::ClipDistance => "gl_ClipDistance",
+ Bi::CullDistance => "gl_CullDistance",
+ Bi::InstanceIndex => "uint(gl_InstanceID)",
+ Bi::PointSize => "gl_PointSize",
+ Bi::VertexIndex => "uint(gl_VertexID)",
+ // fragment
+ Bi::FragDepth => "gl_FragDepth",
+ Bi::FrontFacing => "gl_FrontFacing",
+ Bi::PrimitiveIndex => "uint(gl_PrimitiveID)",
+ Bi::SampleIndex => "gl_SampleID",
+ Bi::SampleMask => {
+ if output {
+ "gl_SampleMask"
+ } else {
+ "gl_SampleMaskIn"
+ }
+ }
+ // compute
+ Bi::GlobalInvocationId => "gl_GlobalInvocationID",
+ Bi::LocalInvocationId => "gl_LocalInvocationID",
+ Bi::LocalInvocationIndex => "gl_LocalInvocationIndex",
+ Bi::WorkGroupId => "gl_WorkGroupID",
+ Bi::WorkGroupSize => "gl_WorkGroupSize",
+ Bi::NumWorkGroups => "gl_NumWorkGroups",
+ }
+}
+
+/// Helper function that returns the string corresponding to the address space
+const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static str> {
+ use crate::AddressSpace as As;
+
+ match space {
+ As::Function => None,
+ As::Private => None,
+ As::Storage { .. } => Some("buffer"),
+ As::Uniform => Some("uniform"),
+ As::Handle => Some("uniform"),
+ As::WorkGroup => Some("shared"),
+ As::PushConstant => Some("uniform"),
+ }
+}
+
+/// Helper function that returns the string corresponding to the glsl interpolation qualifier
+const fn glsl_interpolation(interpolation: crate::Interpolation) -> &'static str {
+ use crate::Interpolation as I;
+
+ match interpolation {
+ I::Perspective => "smooth",
+ I::Linear => "noperspective",
+ I::Flat => "flat",
+ }
+}
+
+/// Return the GLSL auxiliary qualifier for the given sampling value.
+const fn glsl_sampling(sampling: crate::Sampling) -> Option<&'static str> {
+ use crate::Sampling as S;
+
+ match sampling {
+ S::Center => None,
+ S::Centroid => Some("centroid"),
+ S::Sample => Some("sample"),
+ }
+}
+
+/// Helper function that returns the glsl dimension string of [`ImageDimension`](crate::ImageDimension)
+const fn glsl_dimension(dim: crate::ImageDimension) -> &'static str {
+ use crate::ImageDimension as IDim;
+
+ match dim {
+ IDim::D1 => "1D",
+ IDim::D2 => "2D",
+ IDim::D3 => "3D",
+ IDim::Cube => "Cube",
+ }
+}
+
+/// Helper function that returns the glsl storage format string of [`StorageFormat`](crate::StorageFormat)
+const fn glsl_storage_format(format: crate::StorageFormat) -> &'static str {
+ use crate::StorageFormat as Sf;
+
+ match format {
+ Sf::R8Unorm => "r8",
+ Sf::R8Snorm => "r8_snorm",
+ Sf::R8Uint => "r8ui",
+ Sf::R8Sint => "r8i",
+ Sf::R16Uint => "r16ui",
+ Sf::R16Sint => "r16i",
+ Sf::R16Float => "r16f",
+ Sf::Rg8Unorm => "rg8",
+ Sf::Rg8Snorm => "rg8_snorm",
+ Sf::Rg8Uint => "rg8ui",
+ Sf::Rg8Sint => "rg8i",
+ Sf::R32Uint => "r32ui",
+ Sf::R32Sint => "r32i",
+ Sf::R32Float => "r32f",
+ Sf::Rg16Uint => "rg16ui",
+ Sf::Rg16Sint => "rg16i",
+ Sf::Rg16Float => "rg16f",
+ Sf::Rgba8Unorm => "rgba8",
+ Sf::Rgba8Snorm => "rgba8_snorm",
+ Sf::Rgba8Uint => "rgba8ui",
+ Sf::Rgba8Sint => "rgba8i",
+ Sf::Rgb10a2Unorm => "rgb10_a2ui",
+ Sf::Rg11b10Float => "r11f_g11f_b10f",
+ Sf::Rg32Uint => "rg32ui",
+ Sf::Rg32Sint => "rg32i",
+ Sf::Rg32Float => "rg32f",
+ Sf::Rgba16Uint => "rgba16ui",
+ Sf::Rgba16Sint => "rgba16i",
+ Sf::Rgba16Float => "rgba16f",
+ Sf::Rgba32Uint => "rgba32ui",
+ Sf::Rgba32Sint => "rgba32i",
+ Sf::Rgba32Float => "rgba32f",
+ }
+}
+
+fn is_value_init_supported(module: &crate::Module, ty: Handle<crate::Type>) -> bool {
+ match module.types[ty].inner {
+ TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => true,
+ TypeInner::Array { base, size, .. } => {
+ size != crate::ArraySize::Dynamic && is_value_init_supported(module, base)
+ }
+ TypeInner::Struct { ref members, .. } => members
+ .iter()
+ .all(|member| is_value_init_supported(module, member.ty)),
+ _ => false,
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs
new file mode 100644
index 0000000000..039bfcce30
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/conv.rs
@@ -0,0 +1,227 @@
+use std::borrow::Cow;
+
+use crate::proc::Alignment;
+
+use super::Error;
+
+impl crate::ScalarKind {
+ pub(super) fn to_hlsl_cast(self) -> &'static str {
+ match self {
+ Self::Float => "asfloat",
+ Self::Sint => "asint",
+ Self::Uint => "asuint",
+ Self::Bool => unreachable!(),
+ }
+ }
+
+ /// Helper function that returns scalar related strings
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar>
+ pub(super) const fn to_hlsl_str(self, width: crate::Bytes) -> Result<&'static str, Error> {
+ match self {
+ Self::Sint => Ok("int"),
+ Self::Uint => Ok("uint"),
+ Self::Float => match width {
+ 2 => Ok("half"),
+ 4 => Ok("float"),
+ 8 => Ok("double"),
+ _ => Err(Error::UnsupportedScalar(self, width)),
+ },
+ Self::Bool => Ok("bool"),
+ }
+ }
+}
+
+impl crate::TypeInner {
+ pub(super) const fn is_matrix(&self) -> bool {
+ match *self {
+ Self::Matrix { .. } => true,
+ _ => false,
+ }
+ }
+
+ pub(super) fn try_size_hlsl(
+ &self,
+ types: &crate::UniqueArena<crate::Type>,
+ constants: &crate::Arena<crate::Constant>,
+ ) -> Result<u32, crate::arena::BadHandle> {
+ Ok(match *self {
+ Self::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let stride = Alignment::from(rows) * width as u32;
+ let last_row_size = rows as u32 * width as u32;
+ ((columns as u32 - 1) * stride) + last_row_size
+ }
+ Self::Array { base, size, stride } => {
+ let count = match size {
+ crate::ArraySize::Constant(handle) => {
+ let constant = constants.try_get(handle)?;
+ constant.to_array_length().unwrap_or(1)
+ }
+ // A dynamically-sized array has to have at least one element
+ crate::ArraySize::Dynamic => 1,
+ };
+ let last_el_size = types[base].inner.try_size_hlsl(types, constants)?;
+ ((count - 1) * stride) + last_el_size
+ }
+ _ => self.try_size(constants)?,
+ })
+ }
+
+ /// Used to generate the name of the wrapped type constructor
+ pub(super) fn hlsl_type_id<'a>(
+ &self,
+ base: crate::Handle<crate::Type>,
+ types: &crate::UniqueArena<crate::Type>,
+ constants: &crate::Arena<crate::Constant>,
+ names: &'a crate::FastHashMap<crate::proc::NameKey, String>,
+ ) -> Result<Cow<'a, str>, Error> {
+ Ok(match types[base].inner {
+ crate::TypeInner::Scalar { kind, width } => Cow::Borrowed(kind.to_hlsl_str(width)?),
+ crate::TypeInner::Vector { size, kind, width } => Cow::Owned(format!(
+ "{}{}",
+ kind.to_hlsl_str(width)?,
+ crate::back::vector_size_str(size)
+ )),
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => Cow::Owned(format!(
+ "{}{}x{}",
+ crate::ScalarKind::Float.to_hlsl_str(width)?,
+ crate::back::vector_size_str(columns),
+ crate::back::vector_size_str(rows),
+ )),
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => Cow::Owned(format!(
+ "array{}_{}_",
+ constants[size].to_array_length().unwrap(),
+ self.hlsl_type_id(base, types, constants, names)?
+ )),
+ crate::TypeInner::Struct { .. } => {
+ Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)])
+ }
+ _ => unreachable!(),
+ })
+ }
+}
+
+impl crate::StorageFormat {
+ pub(super) const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::R16Float => "float",
+ Self::R8Unorm => "unorm float",
+ Self::R8Snorm => "snorm float",
+ Self::R8Uint | Self::R16Uint => "uint",
+ Self::R8Sint | Self::R16Sint => "int",
+
+ Self::Rg16Float => "float2",
+ Self::Rg8Unorm => "unorm float2",
+ Self::Rg8Snorm => "snorm float2",
+
+ Self::Rg8Sint | Self::Rg16Sint => "int2",
+ Self::Rg8Uint | Self::Rg16Uint => "uint2",
+
+ Self::Rg11b10Float => "float3",
+
+ Self::Rgba16Float | Self::R32Float | Self::Rg32Float | Self::Rgba32Float => "float4",
+ Self::Rgba8Unorm | Self::Rgb10a2Unorm => "unorm float4",
+ Self::Rgba8Snorm => "snorm float4",
+
+ Self::Rgba8Uint
+ | Self::Rgba16Uint
+ | Self::R32Uint
+ | Self::Rg32Uint
+ | Self::Rgba32Uint => "uint4",
+ Self::Rgba8Sint
+ | Self::Rgba16Sint
+ | Self::R32Sint
+ | Self::Rg32Sint
+ | Self::Rgba32Sint => "int4",
+ }
+ }
+}
+
+impl crate::BuiltIn {
+ pub(super) fn to_hlsl_str(self) -> Result<&'static str, Error> {
+ Ok(match self {
+ Self::Position { .. } => "SV_Position",
+ // vertex
+ Self::ClipDistance => "SV_ClipDistance",
+ Self::CullDistance => "SV_CullDistance",
+ Self::InstanceIndex => "SV_InstanceID",
+ // based on this page https://docs.microsoft.com/en-us/windows/uwp/gaming/glsl-to-hlsl-reference#comparing-opengl-es-20-with-direct3d-11
+ // No meaning unless you target Direct3D 9
+ Self::PointSize => "PSIZE",
+ Self::VertexIndex => "SV_VertexID",
+ // fragment
+ Self::FragDepth => "SV_Depth",
+ Self::FrontFacing => "SV_IsFrontFace",
+ Self::PrimitiveIndex => "SV_PrimitiveID",
+ Self::SampleIndex => "SV_SampleIndex",
+ Self::SampleMask => "SV_Coverage",
+ // compute
+ Self::GlobalInvocationId => "SV_DispatchThreadID",
+ Self::LocalInvocationId => "SV_GroupThreadID",
+ Self::LocalInvocationIndex => "SV_GroupIndex",
+ Self::WorkGroupId => "SV_GroupID",
+ // The specific semantic we use here doesn't matter, because references
+ // to this field will get replaced with references to `SPECIAL_CBUF_VAR`
+ // in `Writer::write_expr`.
+ Self::NumWorkGroups => "SV_GroupID",
+ Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
+ return Err(Error::Unimplemented(format!("builtin {:?}", self)))
+ }
+ Self::ViewIndex => {
+ return Err(Error::Custom(format!("Unsupported builtin {:?}", self)))
+ }
+ })
+ }
+}
+
+impl crate::Interpolation {
+ /// Return the string corresponding to the HLSL interpolation qualifier.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ // Would be "linear", but it's the default interpolation in SM4 and up
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-struct#interpolation-modifiers-introduced-in-shader-model-4
+ Self::Perspective => None,
+ Self::Linear => Some("noperspective"),
+ Self::Flat => Some("nointerpolation"),
+ }
+ }
+}
+
+impl crate::Sampling {
+ /// Return the HLSL auxiliary qualifier for the given sampling value.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ Self::Center => None,
+ Self::Centroid => Some("centroid"),
+ Self::Sample => Some("sample"),
+ }
+ }
+}
+
+impl crate::AtomicFunction {
+ /// Return the HLSL suffix for the `InterlockedXxx` method.
+ pub(super) const fn to_hlsl_suffix(self) -> &'static str {
+ match self {
+ Self::Add | Self::Subtract => "Add",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { .. } => "", //TODO
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs
new file mode 100644
index 0000000000..ec913ba66d
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/help.rs
@@ -0,0 +1,1195 @@
+/*!
+Helpers for the hlsl backend
+
+Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and hlsl backend:
+
+Due to implementation of `GetDimensions` function in hlsl (<https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>)
+backend can't work with it as an expression.
+Instead, it generates a unique wrapped function per `Expression::ImageQuery`, based on texture info and query function.
+See `WrappedImageQuery` struct that represents a unique function and will be generated before writing all statements and expressions.
+This allowed to works with `Expression::ImageQuery` as expression and write wrapped function.
+
+For example:
+```wgsl
+let dim_1d = textureDimensions(image_1d);
+```
+
+```hlsl
+int NagaDimensions1D(Texture1D<float4>)
+{
+ uint4 ret;
+ image_1d.GetDimensions(ret.x);
+ return ret.x;
+}
+
+int dim_1d = NagaDimensions1D(image_1d);
+```
+*/
+
+use super::{super::FunctionCtx, BackendResult};
+use crate::{arena::Handle, proc::NameKey};
+use std::fmt::Write;
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedArrayLength {
+ pub(super) writable: bool,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedImageQuery {
+ pub(super) dim: crate::ImageDimension,
+ pub(super) arrayed: bool,
+ pub(super) class: crate::ImageClass,
+ pub(super) query: ImageQuery,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedConstructor {
+ pub(super) ty: Handle<crate::Type>,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedStructMatrixAccess {
+ pub(super) ty: Handle<crate::Type>,
+ pub(super) index: u32,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedMatCx2 {
+ pub(super) columns: crate::VectorSize,
+}
+
+/// HLSL backend requires its own `ImageQuery` enum.
+///
+/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
+/// IR version can't be unique per function, because it's store mipmap level as an expression.
+///
+/// For example:
+/// ```wgsl
+/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1);
+/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1);
+/// ```
+///
+/// ```ir
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [1],
+/// ),
+/// },
+/// },
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [2],
+/// ),
+/// },
+/// },
+/// ```
+///
+/// HLSL should generate only 1 function for this case.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) enum ImageQuery {
+ Size,
+ SizeLevel,
+ NumLevels,
+ NumLayers,
+ NumSamples,
+}
+
+impl From<crate::ImageQuery> for ImageQuery {
+ fn from(q: crate::ImageQuery) -> Self {
+ use crate::ImageQuery as Iq;
+ match q {
+ Iq::Size { level: Some(_) } => ImageQuery::SizeLevel,
+ Iq::Size { level: None } => ImageQuery::Size,
+ Iq::NumLevels => ImageQuery::NumLevels,
+ Iq::NumLayers => ImageQuery::NumLayers,
+ Iq::NumSamples => ImageQuery::NumSamples,
+ }
+ }
+}
+
+impl<'a, W: Write> super::Writer<'a, W> {
+ pub(super) fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ ) -> BackendResult {
+ let access_str = match class {
+ crate::ImageClass::Storage { .. } => "RW",
+ _ => "",
+ };
+ let dim_str = dim.to_hlsl_str();
+ let arrayed_str = if arrayed { "Array" } else { "" };
+ write!(self.out, "{}Texture{}{}", access_str, dim_str, arrayed_str)?;
+ match class {
+ crate::ImageClass::Depth { multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ write!(self.out, "{}<float>", multi_str)?
+ }
+ crate::ImageClass::Sampled { kind, multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ let scalar_kind_str = kind.to_hlsl_str(4)?;
+ write!(self.out, "{}<{}4>", multi_str, scalar_kind_str)?
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let storage_format_str = format.to_hlsl_str();
+ write!(self.out, "<{}>", storage_format_str)?
+ }
+ }
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_array_length_function_name(
+ &mut self,
+ query: WrappedArrayLength,
+ ) -> BackendResult {
+ let access_str = if query.writable { "RW" } else { "" };
+ write!(self.out, "NagaBufferLength{}", access_str,)?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ArrayLength`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer-getdimensions>
+ pub(super) fn write_wrapped_array_length_function(
+ &mut self,
+ module: &crate::Module,
+ wal: WrappedArrayLength,
+ expr_handle: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "buffer";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ let ret_ty = func_ctx.info[expr_handle].ty.inner_with(&module.types);
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_array_length_function_name(wal)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let access_str = if wal.writable { "RW" } else { "" };
+ writeln!(
+ self.out,
+ "{}ByteAddressBuffer {})",
+ access_str, ARGUMENT_VARIABLE_NAME
+ )?;
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{}uint {};", INDENT, RETURN_VARIABLE_NAME)?;
+ writeln!(
+ self.out,
+ "{}{}.GetDimensions({});",
+ INDENT, ARGUMENT_VARIABLE_NAME, RETURN_VARIABLE_NAME
+ )?;
+
+ // Write return value
+ writeln!(self.out, "{}return {};", INDENT, RETURN_VARIABLE_NAME)?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_image_query_function_name(
+ &mut self,
+ query: WrappedImageQuery,
+ ) -> BackendResult {
+ let dim_str = query.dim.to_hlsl_str();
+ let class_str = match query.class {
+ crate::ImageClass::Sampled { multi: true, .. } => "MS",
+ crate::ImageClass::Depth { multi: true } => "DepthMS",
+ crate::ImageClass::Depth { multi: false } => "Depth",
+ crate::ImageClass::Sampled { multi: false, .. } => "",
+ crate::ImageClass::Storage { .. } => "RW",
+ };
+ let arrayed_str = if query.arrayed { "Array" } else { "" };
+ let query_str = match query.query {
+ ImageQuery::Size => "Dimensions",
+ ImageQuery::SizeLevel => "MipDimensions",
+ ImageQuery::NumLevels => "NumLevels",
+ ImageQuery::NumLayers => "NumLayers",
+ ImageQuery::NumSamples => "NumSamples",
+ };
+
+ write!(
+ self.out,
+ "Naga{}{}{}{}",
+ class_str, query_str, dim_str, arrayed_str
+ )?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ImageQuery`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>
+ pub(super) fn write_wrapped_image_query_function(
+ &mut self,
+ module: &crate::Module,
+ wiq: WrappedImageQuery,
+ expr_handle: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ use crate::{
+ back::{COMPONENTS, INDENT},
+ ImageDimension as IDim,
+ };
+
+ const ARGUMENT_VARIABLE_NAME: &str = "tex";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+ const MIP_LEVEL_PARAM: &str = "mip_level";
+
+ // Write function return type and name
+ let ret_ty = func_ctx.info[expr_handle].ty.inner_with(&module.types);
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_image_query_function_name(wiq)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ // Texture always first parameter
+ self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?;
+ write!(self.out, " {}", ARGUMENT_VARIABLE_NAME)?;
+ // Mipmap is a second parameter if exists
+ if let ImageQuery::SizeLevel = wiq.query {
+ write!(self.out, ", uint {}", MIP_LEVEL_PARAM)?;
+ }
+ writeln!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ let array_coords = if wiq.arrayed { 1 } else { 0 };
+ // extra parameter is the mip level count or the sample count
+ let extra_coords = match wiq.class {
+ crate::ImageClass::Storage { .. } => 0,
+ crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } => 1,
+ };
+
+ // GetDimensions Overloaded Methods
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods
+ let (ret_swizzle, number_of_params) = match wiq.query {
+ ImageQuery::Size | ImageQuery::SizeLevel => {
+ let ret = match wiq.dim {
+ IDim::D1 => "x",
+ IDim::D2 => "xy",
+ IDim::D3 => "xyz",
+ IDim::Cube => "xy",
+ };
+ (ret, ret.len() + array_coords + extra_coords)
+ }
+ ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => {
+ if wiq.arrayed || wiq.dim == IDim::D3 {
+ ("w", 4)
+ } else {
+ ("z", 3)
+ }
+ }
+ };
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{}uint4 {};", INDENT, RETURN_VARIABLE_NAME)?;
+ write!(
+ self.out,
+ "{}{}.GetDimensions(",
+ INDENT, ARGUMENT_VARIABLE_NAME
+ )?;
+ match wiq.query {
+ ImageQuery::SizeLevel => {
+ write!(self.out, "{}, ", MIP_LEVEL_PARAM)?;
+ }
+ _ => match wiq.class {
+ crate::ImageClass::Sampled { multi: true, .. }
+ | crate::ImageClass::Depth { multi: true }
+ | crate::ImageClass::Storage { .. } => {}
+ _ => {
+ // Write zero mipmap level for supported types
+ write!(self.out, "0, ")?;
+ }
+ },
+ }
+
+ for component in COMPONENTS[..number_of_params - 1].iter() {
+ write!(self.out, "{}.{}, ", RETURN_VARIABLE_NAME, component)?;
+ }
+
+ // write last parameter without comma and space for last parameter
+ write!(
+ self.out,
+ "{}.{}",
+ RETURN_VARIABLE_NAME,
+ COMPONENTS[number_of_params - 1]
+ )?;
+
+ writeln!(self.out, ");")?;
+
+ // Write return value
+ writeln!(
+ self.out,
+ "{}return {}.{};",
+ INDENT, RETURN_VARIABLE_NAME, ret_swizzle
+ )?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_constructor_function_name(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ let name = module.types[constructor.ty].inner.hlsl_type_id(
+ constructor.ty,
+ &module.types,
+ &module.constants,
+ &self.names,
+ )?;
+ write!(self.out, "Construct{}", name)?;
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::Compose` for structures.
+ pub(super) fn write_wrapped_constructor_function(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "arg";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner {
+ write!(self.out, "typedef ")?;
+ self.write_type(module, constructor.ty)?;
+ write!(self.out, " ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ self.write_array_size(module, base, size)?;
+ writeln!(self.out, ";")?;
+
+ write!(self.out, "ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ } else {
+ self.write_type(module, constructor.ty)?;
+ }
+ write!(self.out, " ")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+
+ let mut write_arg = |i, ty| -> BackendResult {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, ty)?;
+ write!(self.out, " {}{}", ARGUMENT_VARIABLE_NAME, i)?;
+ if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ Ok(())
+ };
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (i, member) in members.iter().enumerate() {
+ write_arg(i, member.ty)?;
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ let count = module.constants[size].to_array_length().unwrap();
+ for i in 0..count as usize {
+ write_arg(i, base)?;
+ }
+ }
+ _ => unreachable!(),
+ };
+
+ write!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, " {{")?;
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(constructor.ty)];
+ writeln!(
+ self.out,
+ "{}{} {} = ({})0;",
+ INDENT, struct_name, RETURN_VARIABLE_NAME, struct_name
+ )?;
+ for (i, member) in members.iter().enumerate() {
+ let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ columns,
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ for j in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{}{}.{}_{} = {}{}[{}];",
+ INDENT,
+ RETURN_VARIABLE_NAME,
+ field_name,
+ j,
+ ARGUMENT_VARIABLE_NAME,
+ i,
+ j
+ )?;
+ }
+ }
+ ref other => {
+ // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s
+ // (where the inner matrix is represented by a struct with C `float2` members).
+ // See the module-level block comment in mod.rs for details.
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ write!(
+ self.out,
+ "{}{}.{} = (__mat{}x2",
+ INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8
+ )?;
+ if let crate::TypeInner::Array { base, size, .. } = *other {
+ self.write_array_size(module, base, size)?;
+ }
+ writeln!(self.out, "){}{};", ARGUMENT_VARIABLE_NAME, i,)?;
+ } else {
+ writeln!(
+ self.out,
+ "{}{}.{} = {}{};",
+ INDENT,
+ RETURN_VARIABLE_NAME,
+ field_name,
+ ARGUMENT_VARIABLE_NAME,
+ i,
+ )?;
+ }
+ }
+ }
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ write!(self.out, "{}", INDENT)?;
+ self.write_type(module, base)?;
+ write!(self.out, " {}", RETURN_VARIABLE_NAME)?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
+ write!(self.out, " = {{ ")?;
+ let count = module.constants[size].to_array_length().unwrap();
+ for i in 0..count {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{}{}", ARGUMENT_VARIABLE_NAME, i)?;
+ }
+ writeln!(self.out, " }};",)?;
+ }
+ _ => unreachable!(),
+ }
+
+ // Write return value
+ writeln!(self.out, "{}return {};", INDENT, RETURN_VARIABLE_NAME)?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_get_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "GetMat{}On{}", field_name, name)?;
+ Ok(())
+ }
+
+ /// Writes a function used to get a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_get_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+
+ // Write function return type and name
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let ret_ty = &module.types[member.ty].inner;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_struct_matrix_get_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(
+ self.out,
+ "{} {}",
+ struct_name, STRUCT_ARGUMENT_VARIABLE_NAME
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ // Write return value
+ write!(self.out, "{}return ", INDENT)?;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, "(")?;
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(
+ self.out,
+ "{}.{}_{}",
+ STRUCT_ARGUMENT_VARIABLE_NAME, field_name, i
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+ writeln!(self.out, ");")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMat{}On{}", field_name, name)?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(
+ self.out,
+ "{} {}, ",
+ struct_name, STRUCT_ARGUMENT_VARIABLE_NAME
+ )?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ self.write_type(module, member.ty)?;
+ write!(self.out, " {}", MATRIX_ARGUMENT_VARIABLE_NAME)?;
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{}{}.{}_{} = {}[{}];",
+ INDENT,
+ STRUCT_ARGUMENT_VARIABLE_NAME,
+ field_name,
+ i,
+ MATRIX_ARGUMENT_VARIABLE_NAME,
+ i
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatVec{}On{}", field_name, name)?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a vec2 on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(
+ self.out,
+ "{} {}, ",
+ struct_name, STRUCT_ARGUMENT_VARIABLE_NAME
+ )?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let vec_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { rows, width, .. } => crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ },
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &vec_ty)?;
+ write!(
+ self.out,
+ " {}, uint {}",
+ VECTOR_ARGUMENT_VARIABLE_NAME, MATRIX_INDEX_ARGUMENT_VARIABLE_NAME
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{}switch({}) {{",
+ INDENT, MATRIX_INDEX_ARGUMENT_VARIABLE_NAME
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{}case {}: {{ {}.{}_{} = {}; break; }}",
+ INDENT,
+ i,
+ STRUCT_ARGUMENT_VARIABLE_NAME,
+ field_name,
+ i,
+ VECTOR_ARGUMENT_VARIABLE_NAME
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{}}}", INDENT)?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatScalar{}On{}", field_name, name)?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a float on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+ const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_scalar_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(
+ self.out,
+ "{} {}, ",
+ struct_name, STRUCT_ARGUMENT_VARIABLE_NAME
+ )?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let scalar_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { width, .. } => crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ },
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &scalar_ty)?;
+ write!(
+ self.out,
+ " {}, uint {}, uint {}",
+ SCALAR_ARGUMENT_VARIABLE_NAME,
+ MATRIX_INDEX_ARGUMENT_VARIABLE_NAME,
+ VECTOR_INDEX_ARGUMENT_VARIABLE_NAME
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{}switch({}) {{",
+ INDENT, MATRIX_INDEX_ARGUMENT_VARIABLE_NAME
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{}case {}: {{ {}.{}_{}[{}] = {}; break; }}",
+ INDENT,
+ i,
+ STRUCT_ARGUMENT_VARIABLE_NAME,
+ field_name,
+ i,
+ VECTOR_INDEX_ARGUMENT_VARIABLE_NAME,
+ SCALAR_ARGUMENT_VARIABLE_NAME
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{}}}", INDENT)?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ImageQuery` and `Expression::ArrayLength`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>
+ pub(super) fn write_wrapped_functions(
+ &mut self,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ for (handle, _) in func_ctx.expressions.iter() {
+ match func_ctx.expressions[handle] {
+ crate::Expression::ArrayLength(expr) => {
+ let global_expr = match func_ctx.expressions[expr] {
+ crate::Expression::GlobalVariable(_) => expr,
+ crate::Expression::AccessIndex { base, index: _ } => base,
+ ref other => unreachable!("Array length of {:?}", other),
+ };
+ let global_var = match func_ctx.expressions[global_expr] {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &module.global_variables[var_handle]
+ }
+ ref other => unreachable!("Array length of base {:?}", other),
+ };
+ let storage_access = match global_var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wal = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ if !self.wrapped.array_lengths.contains(&wal) {
+ self.write_wrapped_array_length_function(module, wal, handle, func_ctx)?;
+ self.wrapped.array_lengths.insert(wal);
+ }
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ let wiq = match *func_ctx.info[image].ty.inner_with(&module.types) {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ },
+ _ => unreachable!("we only query images"),
+ };
+
+ if !self.wrapped.image_queries.contains(&wiq) {
+ self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?;
+ self.wrapped.image_queries.insert(wiq);
+ }
+ }
+ // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage`
+ // since they will later be used by the fn `write_storage_load`
+ crate::Expression::Load { pointer } => {
+ let pointer_space = func_ctx.info[pointer]
+ .ty
+ .inner_with(&module.types)
+ .pointer_space();
+
+ if let Some(crate::AddressSpace::Storage { .. }) = pointer_space {
+ if let Some(ty) = func_ctx.info[handle].ty.handle() {
+ write_wrapped_constructor(self, ty, module, func_ctx)?;
+ }
+ }
+
+ fn write_wrapped_constructor<W: Write>(
+ writer: &mut super::Writer<'_, W>,
+ ty: Handle<crate::Type>,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for member in members {
+ write_wrapped_constructor(writer, member.ty, module, func_ctx)?;
+ }
+
+ let constructor = WrappedConstructor { ty };
+ if !writer.wrapped.constructors.contains(&constructor) {
+ writer
+ .write_wrapped_constructor_function(module, constructor)?;
+ writer.wrapped.constructors.insert(constructor);
+ }
+ }
+ crate::TypeInner::Array { base, .. } => {
+ write_wrapped_constructor(writer, base, module, func_ctx)?;
+ }
+ _ => {}
+ };
+
+ Ok(())
+ }
+ }
+ crate::Expression::Compose { ty, components: _ } => {
+ let constructor = match module.types[ty].inner {
+ crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
+ WrappedConstructor { ty }
+ }
+ _ => continue,
+ };
+ if !self.wrapped.constructors.contains(&constructor) {
+ self.write_wrapped_constructor_function(module, constructor)?;
+ self.wrapped.constructors.insert(constructor);
+ }
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s
+ // (see top level module docs for details).
+ //
+ // The functions injected here are required to get the matrix accesses working.
+ crate::Expression::AccessIndex { base, index } => {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ crate::TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+ if let crate::TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ let access = WrappedStructMatrixAccess { ty, index };
+
+ if !self.wrapped.struct_matrix_access.contains(&access) {
+ self.write_wrapped_struct_matrix_get_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_vec_function(
+ module, access,
+ )?;
+ self.write_wrapped_struct_matrix_set_scalar_function(
+ module, access,
+ )?;
+ self.wrapped.struct_matrix_access.insert(access);
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ };
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_constructor_function_for_constant(
+ &mut self,
+ module: &crate::Module,
+ constant: &crate::Constant,
+ ) -> BackendResult {
+ if let crate::ConstantInner::Composite { ty, ref components } = constant.inner {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
+ let constructor = WrappedConstructor { ty };
+ if !self.wrapped.constructors.contains(&constructor) {
+ self.write_wrapped_constructor_function(module, constructor)?;
+ self.wrapped.constructors.insert(constructor);
+ }
+ }
+ _ => {}
+ }
+ for constant in components {
+ self.write_wrapped_constructor_function_for_constant(
+ module,
+ &module.constants[*constant],
+ )?;
+ }
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn write_texture_coordinates(
+ &mut self,
+ kind: &str,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ mip_level: Option<Handle<crate::Expression>>,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ // HLSL expects the array index to be merged with the coordinate
+ let extra = array_index.is_some() as usize + (mip_level.is_some()) as usize;
+ if extra == 0 {
+ self.write_expr(module, coordinate, func_ctx)?;
+ } else {
+ let num_coords = match *func_ctx.info[coordinate].ty.inner_with(&module.types) {
+ crate::TypeInner::Scalar { .. } => 1,
+ crate::TypeInner::Vector { size, .. } => size as usize,
+ _ => unreachable!(),
+ };
+ write!(self.out, "{}{}(", kind, num_coords + extra)?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ if let Some(expr) = mip_level {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Ok(())
+ }
+
+ pub(super) fn write_mat_cx2_typedef_and_functions(
+ &mut self,
+ WrappedMatCx2 { columns }: WrappedMatCx2,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ // typedef
+ write!(self.out, "typedef struct {{ ")?;
+ for i in 0..columns as u8 {
+ write!(self.out, "float2 _{}; ", i)?;
+ }
+ writeln!(self.out, "}} __mat{}x2;", columns as u8)?;
+
+ // __get_col_of_mat
+ writeln!(
+ self.out,
+ "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{}switch(idx) {{", INDENT)?;
+ for i in 0..columns as u8 {
+ writeln!(self.out, "{}case {}: {{ return mat._{}; }}", INDENT, i, i)?;
+ }
+ writeln!(self.out, "{}default: {{ return (float2)0; }}", INDENT)?;
+ writeln!(self.out, "{}}}", INDENT)?;
+ writeln!(self.out, "}}")?;
+
+ // __set_col_of_mat
+ writeln!(
+ self.out,
+ "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{}switch(idx) {{", INDENT)?;
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{}case {}: {{ mat._{} = value; break; }}",
+ INDENT, i, i
+ )?;
+ }
+ writeln!(self.out, "{}}}", INDENT)?;
+ writeln!(self.out, "}}")?;
+
+ // __set_el_of_mat
+ writeln!(
+ self.out,
+ "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{}switch(idx) {{", INDENT)?;
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{}case {}: {{ mat._{}[vec_idx] = value; break; }}",
+ INDENT, i, i
+ )?;
+ }
+ writeln!(self.out, "{}}}", INDENT)?;
+ writeln!(self.out, "}}")?;
+
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_all_mat_cx2_typedefs_and_functions(
+ &mut self,
+ module: &crate::Module,
+ ) -> BackendResult {
+ for (handle, _) in module.global_variables.iter() {
+ let global = &module.global_variables[handle];
+
+ if global.space == crate::AddressSpace::Uniform {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, global.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if !self.wrapped.mat_cx2s.contains(&entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ self.wrapped.mat_cx2s.insert(entry);
+ }
+ }
+ }
+ }
+
+ for (_, ty) in module.types.iter() {
+ if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
+ for member in members.iter() {
+ if let crate::TypeInner::Array { .. } = module.types[member.ty].inner {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if !self.wrapped.mat_cx2s.contains(&entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ self.wrapped.mat_cx2s.insert(entry);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/keywords.rs b/third_party/rust/naga/src/back/hlsl/keywords.rs
new file mode 100644
index 0000000000..7519b767a1
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/keywords.rs
@@ -0,0 +1,166 @@
+/*!
+HLSL Reserved Words
+- <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-appendix-keywords>
+- <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-appendix-reserved-words>
+*/
+
+pub const RESERVED: &[&str] = &[
+ "AppendStructuredBuffer",
+ "asm",
+ "asm_fragment",
+ "BlendState",
+ "bool",
+ "break",
+ "Buffer",
+ "ByteAddressBuffer",
+ "case",
+ "cbuffer",
+ "centroid",
+ "class",
+ "column_major",
+ "compile",
+ "compile_fragment",
+ "CompileShader",
+ "const",
+ "continue",
+ "ComputeShader",
+ "ConsumeStructuredBuffer",
+ "default",
+ "DepthStencilState",
+ "DepthStencilView",
+ "discard",
+ "do",
+ "double",
+ "DomainShader",
+ "dword",
+ "else",
+ "export",
+ "extern",
+ "false",
+ "float",
+ "for",
+ "fxgroup",
+ "GeometryShader",
+ "groupshared",
+ "half",
+ "Hullshader",
+ "if",
+ "in",
+ "inline",
+ "inout",
+ "InputPatch",
+ "int",
+ "interface",
+ "line",
+ "lineadj",
+ "linear",
+ "LineStream",
+ "matrix",
+ "min16float",
+ "min10float",
+ "min16int",
+ "min12int",
+ "min16uint",
+ "namespace",
+ "nointerpolation",
+ "noperspective",
+ "NULL",
+ "out",
+ "OutputPatch",
+ "packoffset",
+ "pass",
+ "pixelfragment",
+ "PixelShader",
+ "point",
+ "PointStream",
+ "precise",
+ "RasterizerState",
+ "RenderTargetView",
+ "return",
+ "register",
+ "row_major",
+ "RWBuffer",
+ "RWByteAddressBuffer",
+ "RWStructuredBuffer",
+ "RWTexture1D",
+ "RWTexture1DArray",
+ "RWTexture2D",
+ "RWTexture2DArray",
+ "RWTexture3D",
+ "sample",
+ "sampler",
+ "SamplerState",
+ "SamplerComparisonState",
+ "shared",
+ "snorm",
+ "stateblock",
+ "stateblock_state",
+ "static",
+ "string",
+ "struct",
+ "switch",
+ "StructuredBuffer",
+ "tbuffer",
+ "technique",
+ "technique10",
+ "technique11",
+ "texture",
+ "Texture1D",
+ "Texture1DArray",
+ "Texture2D",
+ "Texture2DArray",
+ "Texture2DMS",
+ "Texture2DMSArray",
+ "Texture3D",
+ "TextureCube",
+ "TextureCubeArray",
+ "true",
+ "typedef",
+ "triangle",
+ "triangleadj",
+ "TriangleStream",
+ "uint",
+ "uniform",
+ "unorm",
+ "unsigned",
+ "vector",
+ "vertexfragment",
+ "VertexShader",
+ "void",
+ "volatile",
+ "while",
+ "auto",
+ "case",
+ "catch",
+ "char",
+ "class",
+ "const_cast",
+ "default",
+ "delete",
+ "dynamic_cast",
+ "enum",
+ "explicit",
+ "friend",
+ "goto",
+ "long",
+ "mutable",
+ "new",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "reinterpret_cast",
+ "short",
+ "signed",
+ "sizeof",
+ "static_cast",
+ "template",
+ "this",
+ "throw",
+ "try",
+ "typename",
+ "union",
+ "unsigned",
+ "using",
+ "virtual",
+];
diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs
new file mode 100644
index 0000000000..333ea2cf1a
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/mod.rs
@@ -0,0 +1,280 @@
+/*!
+Backend for [HLSL][hlsl] (High-Level Shading Language).
+
+# Supported shader model versions:
+- 5.0
+- 5.1
+- 6.0
+
+# Layout of values in `uniform` buffers
+
+WGSL's ["Internal Layout of Values"][ilov] rules specify how each WGSL
+type should be stored in `uniform` and `storage` buffers. The HLSL we
+generate must access values in that form, even when it is not what
+HLSL would use normally.
+
+The rules described here only apply to WGSL `uniform` variables. WGSL
+`storage` buffers are translated as HLSL `ByteAddressBuffers`, for
+which we generate `Load` and `Store` method calls with explicit byte
+offsets. WGSL pipeline inputs must be scalars or vectors; they cannot
+be matrices, which is where the interesting problems arise.
+
+## Row- and column-major ordering for matrices
+
+WGSL specifies that matrices in uniform buffers are stored in
+column-major order. This matches HLSL's default, so one might expect
+things to be straightforward. Unfortunately, WGSL and HLSL disagree on
+what indexing a matrix means: in WGSL, `m[i]` retrieves the `i`'th
+*column* of `m`, whereas in HLSL it retrieves the `i`'th *row*. We
+want to avoid translating `m[i]` into some complicated reassembly of a
+vector from individually fetched components, so this is a problem.
+
+However, with a bit of trickery, it is possible to use HLSL's `m[i]`
+as the translation of WGSL's `m[i]`:
+
+- We declare all matrices in uniform buffers in HLSL with the
+ `row_major` qualifier, and transpose the row and column counts: a
+ WGSL `mat3x4<f32>`, say, becomes an HLSL `row_major float3x4`. (Note
+ that WGSL and HLSL type names put the row and column in reverse
+ order.) Since the HLSL type is the transpose of how WebGPU directs
+ the user to store the data, HLSL will load all matrices transposed.
+
+- Since matrices are transposed, an HLSL indexing expression retrieves
+ the "columns" of the intended WGSL value, as desired.
+
+- For vector-matrix multiplication, since `mul(transpose(m), v)` is
+ equivalent to `mul(v, m)` (note the reversal of the arguments), and
+ `mul(v, transpose(m))` is equivalent to `mul(m, v)`, we can
+ translate WGSL `m * v` and `v * m` to HLSL by simply reversing the
+ arguments to `mul`.
+
+## Padding in two-row matrices
+
+An HLSL `row_major floatKx2` matrix has padding between its rows that
+the WGSL `matKx2<f32>` matrix it represents does not. HLSL stores all
+matrix rows [aligned on 16-byte boundaries][16bb], whereas WGSL says
+that the columns of a `matKx2<f32>` need only be [aligned as required
+for `vec2<f32>`][ilov], which is [eight-byte alignment][8bb].
+
+To compensate for this, any time a `matKx2<f32>` appears in a WGSL
+`uniform` variable, whether directly as the variable's type or as part
+of a struct/array, we actually emit `K` separate `float2` members, and
+assemble/disassemble the matrix from its columns (in WGSL; rows in
+HLSL) upon load and store.
+
+For example, the following WGSL struct type:
+
+```ignore
+struct Baz {
+ m: mat3x2<f32>,
+}
+```
+
+is rendered as the HLSL struct type:
+
+```ignore
+struct Baz {
+ float2 m_0; float2 m_1; float2 m_2;
+};
+```
+
+The `wrapped_struct_matrix` functions in `help.rs` generate HLSL
+helper functions to access such members, converting between the stored
+form and the HLSL matrix types appropriately. For example, for reading
+the member `m` of the `Baz` struct above, we emit:
+
+```ignore
+float3x2 GetMatmOnBaz(Baz obj) {
+ return float3x2(obj.m_0, obj.m_1, obj.m_2);
+}
+```
+
+We also emit an analogous `Set` function, as well as functions for
+accessing individual columns by dynamic index.
+
+[hlsl]: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl
+[ilov]: https://gpuweb.github.io/gpuweb/wgsl/#internal-value-layout
+[16bb]: https://github.com/microsoft/DirectXShaderCompiler/wiki/Buffer-Packing#constant-buffer-packing
+[8bb]: https://gpuweb.github.io/gpuweb/wgsl/#alignment-and-size
+*/
+
+mod conv;
+mod help;
+mod keywords;
+mod storage;
+mod writer;
+
+use std::fmt::Error as FmtError;
+use thiserror::Error;
+
+use crate::proc;
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BindTarget {
+ pub space: u8,
+ pub register: u32,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
+
+/// A HLSL shader model version.
+#[allow(non_snake_case, non_camel_case_types)]
+#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum ShaderModel {
+ V5_0,
+ V5_1,
+ V6_0,
+}
+
+impl ShaderModel {
+ pub const fn to_str(self) -> &'static str {
+ match self {
+ Self::V5_0 => "5_0",
+ Self::V5_1 => "5_1",
+ Self::V6_0 => "6_0",
+ }
+ }
+}
+
+impl crate::ShaderStage {
+ pub const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::Vertex => "vs",
+ Self::Fragment => "ps",
+ Self::Compute => "cs",
+ }
+ }
+}
+
+impl crate::ImageDimension {
+ const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::D1 => "1D",
+ Self::D2 => "2D",
+ Self::D3 => "3D",
+ Self::Cube => "Cube",
+ }
+ }
+}
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("mapping of {0:?} is missing")]
+ MissingBinding(crate::ResourceBinding),
+}
+
+/// Configuration used in the [`Writer`].
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// The hlsl shader model to be used
+ pub shader_model: ShaderModel,
+ /// Map of resources association to binding locations.
+ pub binding_map: BindingMap,
+ /// Don't panic on missing bindings, instead generate any HLSL.
+ pub fake_missing_bindings: bool,
+ /// Add special constants to `SV_VertexIndex` and `SV_InstanceIndex`,
+ /// to make them work like in Vulkan/Metal, with help of the host.
+ pub special_constants_binding: Option<BindTarget>,
+ /// Bind target of the push constant buffer
+ pub push_constants_target: Option<BindTarget>,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ shader_model: ShaderModel::V5_1,
+ binding_map: BindingMap::default(),
+ fake_missing_bindings: true,
+ special_constants_binding: None,
+ push_constants_target: None,
+ }
+ }
+}
+
+impl Options {
+ fn resolve_resource_binding(
+ &self,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<BindTarget, EntryPointError> {
+ match self.binding_map.get(res_binding) {
+ Some(target) => Ok(target.clone()),
+ None if self.fake_missing_bindings => Ok(BindTarget {
+ space: res_binding.group as u8,
+ register: res_binding.binding,
+ binding_array_size: None,
+ }),
+ None => Err(EntryPointError::MissingBinding(res_binding.clone())),
+ }
+ }
+}
+
+/// Reflection info for entry point names.
+#[derive(Default)]
+pub struct ReflectionInfo {
+ /// Mapping of the entry point names.
+ ///
+ /// Each item in the array corresponds to an entry point index. The real entry point name may be different if one of the
+ /// reserved words are used.
+ ///
+ /// Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error(transparent)]
+ IoError(#[from] FmtError),
+ #[error("A scalar with an unsupported width was requested: {0:?} {1:?}")]
+ UnsupportedScalar(crate::ScalarKind, crate::Bytes),
+ #[error("{0}")]
+ Unimplemented(String), // TODO: Error used only during development
+ #[error("{0}")]
+ Custom(String),
+}
+
+#[derive(Default)]
+struct Wrapped {
+ array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
+ image_queries: crate::FastHashSet<help::WrappedImageQuery>,
+ constructors: crate::FastHashSet<help::WrappedConstructor>,
+ struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>,
+ mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>,
+}
+
+impl Wrapped {
+ fn clear(&mut self) {
+ self.array_lengths.clear();
+ self.image_queries.clear();
+ self.constructors.clear();
+ self.struct_matrix_access.clear();
+ self.mat_cx2s.clear();
+ }
+}
+
+pub struct Writer<'a, W> {
+ out: W,
+ names: crate::FastHashMap<proc::NameKey, String>,
+ namer: proc::Namer,
+ /// HLSL backend options
+ options: &'a Options,
+ /// Information about entry point arguments and result types.
+ entry_point_io: Vec<writer::EntryPointInterface>,
+ /// Set of expressions that have associated temporary variables
+ named_expressions: crate::NamedExpressions,
+ wrapped: Wrapped,
+ temp_access_chain: Vec<storage::SubAccess>,
+}
diff --git a/third_party/rust/naga/src/back/hlsl/storage.rs b/third_party/rust/naga/src/back/hlsl/storage.rs
new file mode 100644
index 0000000000..4397150453
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/storage.rs
@@ -0,0 +1,433 @@
+/*!
+Logic related to `ByteAddressBuffer` operations.
+
+HLSL backend uses byte address buffers for all storage buffers in IR.
+*/
+
+use super::{super::FunctionCtx, BackendResult, Error};
+use crate::{
+ proc::{Alignment, NameKey, TypeResolution},
+ Handle,
+};
+
+use std::{fmt, mem};
+
+const STORE_TEMP_NAME: &str = "_value";
+
+#[derive(Debug)]
+pub(super) enum SubAccess {
+ Offset(u32),
+ Index {
+ value: Handle<crate::Expression>,
+ stride: u32,
+ },
+}
+
+pub(super) enum StoreValue {
+ Expression(Handle<crate::Expression>),
+ TempIndex {
+ depth: usize,
+ index: u32,
+ ty: TypeResolution,
+ },
+ TempAccess {
+ depth: usize,
+ base: Handle<crate::Type>,
+ member_index: u32,
+ },
+}
+
+impl<W: fmt::Write> super::Writer<'_, W> {
+ pub(super) fn write_storage_address(
+ &mut self,
+ module: &crate::Module,
+ chain: &[SubAccess],
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ if chain.is_empty() {
+ write!(self.out, "0")?;
+ }
+ for (i, access) in chain.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, "+")?;
+ }
+ match *access {
+ SubAccess::Offset(offset) => {
+ write!(self.out, "{}", offset)?;
+ }
+ SubAccess::Index { value, stride } => {
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, "*{}", stride)?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ sequence: I,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ for (i, (ty_resolution, offset)) in sequence.enumerate() {
+ // add the index temporarily
+ self.temp_access_chain.push(SubAccess::Offset(offset));
+ if i != 0 {
+ write!(self.out, ", ")?;
+ };
+ self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?;
+ self.temp_access_chain.pop();
+ }
+ Ok(())
+ }
+
+ /// Helper function to write down the Load operation on a `ByteAddressBuffer`.
+ pub(super) fn write_storage_load(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ result_ty: TypeResolution,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *result_ty.inner_with(&module.types) {
+ crate::TypeInner::Scalar { kind, width: _ } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = kind.to_hlsl_cast();
+ write!(self.out, "{}({}.Load(", cast, var_name)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector {
+ size,
+ kind,
+ width: _,
+ } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = kind.to_hlsl_cast();
+ write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ crate::ScalarKind::Float.to_hlsl_str(width)?,
+ columns as u8,
+ rows as u8,
+ )?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * width as u32;
+ let iter = (0..columns as u32).map(|i| {
+ let ty_inner = crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ };
+ (TypeResolution::Value(ty_inner), i * row_stride)
+ });
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(const_handle),
+ ..
+ } => {
+ write!(self.out, "{{")?;
+ let count = module.constants[const_handle].to_array_length().unwrap();
+ let stride = module.types[base].inner.size(&module.constants);
+ let iter = (0..count).map(|i| (TypeResolution::Handle(base), stride * i));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, "}}")?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ let constructor = super::help::WrappedConstructor {
+ ty: result_ty.handle().unwrap(),
+ };
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ write!(self.out, "(")?;
+ let iter = members
+ .iter()
+ .map(|m| (TypeResolution::Handle(m.ty), m.offset));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ fn write_store_value(
+ &mut self,
+ module: &crate::Module,
+ value: &StoreValue,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *value {
+ StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?,
+ StoreValue::TempIndex {
+ depth,
+ index,
+ ty: _,
+ } => write!(self.out, "{}{}[{}]", STORE_TEMP_NAME, depth, index)?,
+ StoreValue::TempAccess {
+ depth,
+ base,
+ member_index,
+ } => {
+ let name = &self.names[&NameKey::StructMember(base, member_index)];
+ write!(self.out, "{}{}.{}", STORE_TEMP_NAME, depth, name)?
+ }
+ }
+ Ok(())
+ }
+
+ /// Helper function to write down the Store operation on a `ByteAddressBuffer`.
+ pub(super) fn write_storage_store(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ value: StoreValue,
+ func_ctx: &FunctionCtx,
+ level: crate::back::Level,
+ ) -> BackendResult {
+ let temp_resolution;
+ let ty_resolution = match value {
+ StoreValue::Expression(expr) => &func_ctx.info[expr].ty,
+ StoreValue::TempIndex {
+ depth: _,
+ index: _,
+ ref ty,
+ } => ty,
+ StoreValue::TempAccess {
+ depth: _,
+ base,
+ member_index,
+ } => {
+ let ty_handle = match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ members[member_index as usize].ty
+ }
+ _ => unreachable!(),
+ };
+ temp_resolution = TypeResolution::Handle(ty_handle);
+ &temp_resolution
+ }
+ };
+ match *ty_resolution.inner_with(&module.types) {
+ crate::TypeInner::Scalar { .. } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{}{}.Store(", level, var_name)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{}{{", level)?;
+ let depth = level.0 + 1;
+ write!(
+ self.out,
+ "{}{}{}x{} {}{} = ",
+ level.next(),
+ crate::ScalarKind::Float.to_hlsl_str(width)?,
+ columns as u8,
+ rows as u8,
+ STORE_TEMP_NAME,
+ depth,
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * width as u32;
+
+ // then iterate the stores
+ for i in 0..columns as u32 {
+ self.temp_access_chain
+ .push(SubAccess::Offset(i * row_stride));
+ let ty_inner = crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ };
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Value(ty_inner),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{}}}", level)?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(const_handle),
+ ..
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{}{{", level)?;
+ write!(self.out, "{}", level.next())?;
+ self.write_value_type(module, &module.types[base].inner)?;
+ let depth = level.next().0;
+ write!(self.out, " {}{}", STORE_TEMP_NAME, depth)?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(const_handle))?;
+ write!(self.out, " = ")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ let count = module.constants[const_handle].to_array_length().unwrap();
+ let stride = module.types[base].inner.size(&module.constants);
+ for i in 0..count {
+ self.temp_access_chain.push(SubAccess::Offset(i * stride));
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Handle(base),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{}}}", level)?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{}{{", level)?;
+ let depth = level.next().0;
+ let struct_ty = ty_resolution.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(struct_ty)];
+ write!(
+ self.out,
+ "{}{} {}{} = ",
+ level.next(),
+ struct_name,
+ STORE_TEMP_NAME,
+ depth
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ for (i, member) in members.iter().enumerate() {
+ self.temp_access_chain
+ .push(SubAccess::Offset(member.offset));
+ let sv = StoreValue::TempAccess {
+ depth,
+ base: struct_ty,
+ member_index: i as u32,
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{}}}", level)?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ pub(super) fn fill_access_chain(
+ &mut self,
+ module: &crate::Module,
+ mut cur_expr: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> Result<Handle<crate::GlobalVariable>, Error> {
+ enum AccessIndex {
+ Expression(Handle<crate::Expression>),
+ Constant(u32),
+ }
+ enum Parent<'a> {
+ Array { stride: u32 },
+ Struct(&'a [crate::StructMember]),
+ }
+ self.temp_access_chain.clear();
+
+ loop {
+ let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
+ crate::Expression::GlobalVariable(handle) => return Ok(handle),
+ crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
+ crate::Expression::AccessIndex { base, index } => {
+ (base, AccessIndex::Constant(index))
+ }
+ ref other => {
+ return Err(Error::Unimplemented(format!(
+ "Pointer access of {:?}",
+ other
+ )))
+ }
+ };
+
+ let parent = match *func_ctx.info[next_expr].ty.inner_with(&module.types) {
+ crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
+ crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
+ crate::TypeInner::Vector { width, .. } => Parent::Array {
+ stride: width as u32,
+ },
+ crate::TypeInner::Matrix { rows, width, .. } => Parent::Array {
+ // The stride between matrices is the count of rows as this is how
+ // long each column is.
+ stride: Alignment::from(rows) * width as u32,
+ },
+ _ => unreachable!(),
+ },
+ crate::TypeInner::ValuePointer { width, .. } => Parent::Array {
+ stride: width as u32,
+ },
+ _ => unreachable!(),
+ };
+
+ let sub = match (parent, access_index) {
+ (Parent::Array { stride }, AccessIndex::Expression(value)) => {
+ SubAccess::Index { value, stride }
+ }
+ (Parent::Array { stride }, AccessIndex::Constant(index)) => {
+ SubAccess::Offset(stride * index)
+ }
+ (Parent::Struct(members), AccessIndex::Constant(index)) => {
+ SubAccess::Offset(members[index as usize].offset)
+ }
+ (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(),
+ };
+
+ self.temp_access_chain.push(sub);
+ cur_expr = next_expr;
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs
new file mode 100644
index 0000000000..e29d2c41db
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/writer.rs
@@ -0,0 +1,2980 @@
+use super::{
+ help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess},
+ storage::StoreValue,
+ BackendResult, Error, Options,
+};
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, Module, ScalarKind, ShaderStage, TypeInner,
+};
+use std::{fmt, mem};
+
+const LOCATION_SEMANTIC: &str = "LOC";
+const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
+const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
+const SPECIAL_BASE_VERTEX: &str = "base_vertex";
+const SPECIAL_BASE_INSTANCE: &str = "base_instance";
+const SPECIAL_OTHER: &str = "other";
+
+struct EpStructMember {
+ name: String,
+ ty: Handle<crate::Type>,
+ // technically, this should always be `Some`
+ binding: Option<crate::Binding>,
+ index: u32,
+}
+
+/// Structure contains information required for generating
+/// wrapped structure of all entry points arguments
+struct EntryPointBinding {
+ /// Name of the fake EP argument that contains the struct
+ /// with all the flattened input data.
+ arg_name: String,
+ /// Generated structure name
+ ty_name: String,
+ /// Members of generated structure
+ members: Vec<EpStructMember>,
+}
+
+pub(super) struct EntryPointInterface {
+ /// If `Some`, the input of an entry point is gathered in a special
+ /// struct with members sorted by binding.
+ /// The `EntryPointBinding::members` array is sorted by index,
+ /// so that we can walk it in `write_ep_arguments_initialization`.
+ input: Option<EntryPointBinding>,
+ /// If `Some`, the output of an entry point is flattened.
+ /// The `EntryPointBinding::members` array is sorted by binding,
+ /// So that we can walk it in `Statement::Return` handler.
+ output: Option<EntryPointBinding>,
+}
+
+#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
+enum InterfaceKey {
+ Location(u32),
+ BuiltIn(crate::BuiltIn),
+ Other,
+}
+
+impl InterfaceKey {
+ const fn new(binding: Option<&crate::Binding>) -> Self {
+ match binding {
+ Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
+ Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
+ None => Self::Other,
+ }
+ }
+}
+
+#[derive(Copy, Clone, PartialEq)]
+enum Io {
+ Input,
+ Output,
+}
+
+impl<'a, W: fmt::Write> super::Writer<'a, W> {
+ pub fn new(out: W, options: &'a Options) -> Self {
+ Self {
+ out,
+ names: crate::FastHashMap::default(),
+ namer: proc::Namer::default(),
+ options,
+ entry_point_io: Vec::new(),
+ named_expressions: crate::NamedExpressions::default(),
+ wrapped: super::Wrapped::default(),
+ temp_access_chain: Vec::new(),
+ }
+ }
+
+ fn reset(&mut self, module: &Module) {
+ self.names.clear();
+ self.namer
+ .reset(module, super::keywords::RESERVED, &[], &mut self.names);
+ self.entry_point_io.clear();
+ self.named_expressions.clear();
+ self.wrapped.clear();
+ }
+
+ pub fn write(
+ &mut self,
+ module: &Module,
+ module_info: &valid::ModuleInfo,
+ ) -> Result<super::ReflectionInfo, Error> {
+ self.reset(module);
+
+ // Write special constants, if needed
+ if let Some(ref bt) = self.options.special_constants_binding {
+ writeln!(self.out, "struct {} {{", SPECIAL_CBUF_TYPE)?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_BASE_VERTEX)?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_BASE_INSTANCE)?;
+ writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
+ writeln!(self.out, "}};")?;
+ write!(
+ self.out,
+ "ConstantBuffer<{}> {}: register(b{}",
+ SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
+ )?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ writeln!(self.out, ");")?;
+ }
+
+ // Write all constants
+ // For example, input wgsl shader:
+ // ```wgsl
+ // let c_scale: f32 = 1.2;
+ // return VertexOutput(uv, vec4<f32>(c_scale * pos, 0.0, 1.0));
+ // ```
+ //
+ // Output shader:
+ // ```hlsl
+ // static const float c_scale = 1.2;
+ // const VertexOutput vertexoutput1 = { vertexinput.uv3, float4((c_scale * vertexinput.pos1), 0.0, 1.0) };
+ // ```
+ //
+ // If we remove `write_global_constant` `c_scale` will be inlined.
+ for (handle, constant) in module.constants.iter() {
+ if constant.name.is_some() {
+ self.write_global_constant(module, &constant.inner, handle)?;
+ }
+ }
+
+ // Extra newline for readability
+ writeln!(self.out)?;
+
+ // Save all entry point output types
+ let ep_results = module
+ .entry_points
+ .iter()
+ .map(|ep| (ep.stage, ep.function.result.clone()))
+ .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
+
+ self.write_all_mat_cx2_typedefs_and_functions(module)?;
+
+ // Write all structs
+ for (handle, ty) in module.types.iter() {
+ if let TypeInner::Struct { ref members, span } = ty.inner {
+ if module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&module.types)
+ {
+ // unsized arrays can only be in storage buffers,
+ // for which we use `ByteAddressBuffer` anyway.
+ continue;
+ }
+
+ let ep_result = ep_results.iter().find(|e| {
+ if let Some(ref result) = e.1 {
+ result.ty == handle
+ } else {
+ false
+ }
+ });
+
+ self.write_struct(
+ module,
+ handle,
+ members,
+ span,
+ ep_result.map(|r| (r.0, Io::Output)),
+ )?;
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write wrapped constructor functions used in constants
+ for (_, constant) in module.constants.iter() {
+ self.write_wrapped_constructor_function_for_constant(module, constant)?;
+ }
+
+ // Write all globals
+ for (ty, _) in module.global_variables.iter() {
+ self.write_global(module, ty)?;
+ }
+
+ if !module.global_variables.is_empty() {
+ // Add extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Write all entry points wrapped structs
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ let ep_io = self.write_ep_interface(module, &ep.function, ep.stage, &ep_name)?;
+ self.entry_point_io.push(ep_io);
+ }
+
+ // Write all regular functions
+ for (handle, function) in module.functions.iter() {
+ let info = &module_info[handle];
+
+ // Check if all of the globals are accessible
+ if !self.options.fake_missing_bindings {
+ if let Some((var_handle, _)) =
+ module
+ .global_variables
+ .iter()
+ .find(|&(var_handle, var)| match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ self.options.resolve_resource_binding(binding).is_err()
+ }
+ _ => false,
+ })
+ {
+ log::info!(
+ "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
+ handle,
+ function.name,
+ var_handle
+ );
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::Function(handle),
+ info,
+ expressions: &function.expressions,
+ named_expressions: &function.named_expressions,
+ };
+ let name = self.names[&NameKey::Function(handle)].clone();
+
+ // Write wrapped function for `Expression::ImageQuery` and `Expressions::ArrayLength`
+ // before writing all statements and expressions.
+ self.write_wrapped_functions(module, &ctx)?;
+
+ self.write_function(module, name.as_str(), function, &ctx)?;
+
+ writeln!(self.out)?;
+ }
+
+ let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
+
+ // Write all entry points
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let info = module_info.get_entry_point(index);
+
+ if !self.options.fake_missing_bindings {
+ let mut ep_error = None;
+ for (var_handle, var) in module.global_variables.iter() {
+ match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ ep_error = Some(err);
+ break;
+ }
+ }
+ _ => {}
+ }
+ }
+ if let Some(err) = ep_error {
+ entry_point_names.push(Err(err));
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::EntryPoint(index as u16),
+ info,
+ expressions: &ep.function.expressions,
+ named_expressions: &ep.function.named_expressions,
+ };
+
+ // Write wrapped function for `Expression::ImageQuery` and `Expressions::ArrayLength`
+ // before writing all statements and expressions.
+ self.write_wrapped_functions(module, &ctx)?;
+
+ if ep.stage == ShaderStage::Compute {
+ // HLSL is calling workgroup size "num threads"
+ let num_threads = ep.workgroup_size;
+ writeln!(
+ self.out,
+ "[numthreads({}, {}, {})]",
+ num_threads[0], num_threads[1], num_threads[2]
+ )?;
+ }
+
+ let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ self.write_function(module, &name, &ep.function, &ctx)?;
+
+ if index < module.entry_points.len() - 1 {
+ writeln!(self.out)?;
+ }
+
+ entry_point_names.push(Ok(name));
+ }
+
+ Ok(super::ReflectionInfo { entry_point_names })
+ }
+
+ fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
+ write!(self.out, "precise ")?;
+ }
+ crate::Binding::Location {
+ interpolation,
+ sampling,
+ ..
+ } => {
+ if let Some(interpolation) = interpolation {
+ if let Some(string) = interpolation.to_hlsl_str() {
+ write!(self.out, "{} ", string)?
+ }
+ }
+
+ if let Some(sampling) = sampling {
+ if let Some(string) = sampling.to_hlsl_str() {
+ write!(self.out, "{} ", string)?
+ }
+ }
+ }
+ _ => {}
+ }
+
+ Ok(())
+ }
+
+ //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
+ // if they are struct, so that the `stage` argument here could be omitted.
+ fn write_semantic(
+ &mut self,
+ binding: &crate::Binding,
+ stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(builtin) => {
+ let builtin_str = builtin.to_hlsl_str()?;
+ write!(self.out, " : {}", builtin_str)?;
+ }
+ crate::Binding::Location { location, .. } => {
+ if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
+ write!(self.out, " : SV_Target{}", location)?;
+ } else {
+ write!(self.out, " : {}{}", LOCATION_SEMANTIC, location)?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_interface_struct(
+ &mut self,
+ module: &Module,
+ shader_stage: (ShaderStage, Io),
+ struct_name: String,
+ mut members: Vec<EpStructMember>,
+ ) -> Result<EntryPointBinding, Error> {
+ // Sort the members so that first come the user-defined varyings
+ // in ascending locations, and then built-ins. This allows VS and FS
+ // interfaces to match with regards to order.
+ members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
+
+ write!(self.out, "struct {}", struct_name)?;
+ writeln!(self.out, " {{")?;
+ for m in members.iter() {
+ write!(self.out, "{}", back::INDENT)?;
+ if let Some(ref binding) = m.binding {
+ self.write_modifier(binding)?;
+ }
+ self.write_type(module, m.ty)?;
+ write!(self.out, " {}", &m.name)?;
+ if let Some(ref binding) = m.binding {
+ self.write_semantic(binding, Some(shader_stage))?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+
+ match shader_stage.1 {
+ Io::Input => {
+ // bring back the original order
+ members.sort_by_key(|m| m.index);
+ }
+ Io::Output => {
+ // keep it sorted by binding
+ }
+ }
+
+ Ok(EntryPointBinding {
+ arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
+ ty_name: struct_name,
+ members,
+ })
+ }
+
+ /// Flatten all entry point arguments into a single struct.
+ /// This is needed since we need to re-order them: first placing user locations,
+ /// then built-ins.
+ fn write_ep_input_struct(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{:?}Input_{}", stage, entry_point_name);
+
+ let mut fake_members = Vec::new();
+ for arg in func.arguments.iter() {
+ match module.types[arg.ty].inner {
+ TypeInner::Struct { ref members, .. } => {
+ for member in members.iter() {
+ let name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+ }
+ _ => {
+ let member_name = self.namer.call_or(&arg.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: arg.ty,
+ binding: arg.binding.clone(),
+ index,
+ });
+ }
+ }
+ }
+
+ self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
+ }
+
+ /// Flatten all entry point results into a single struct.
+ /// This is needed since we need to re-order them: first placing user locations,
+ /// then built-ins.
+ fn write_ep_output_struct(
+ &mut self,
+ module: &Module,
+ result: &crate::FunctionResult,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{:?}Output_{}", stage, entry_point_name);
+
+ let mut fake_members = Vec::new();
+ let empty = [];
+ let members = match module.types[result.ty].inner {
+ TypeInner::Struct { ref members, .. } => members,
+ ref other => {
+ log::error!("Unexpected {:?} output type without a binding", other);
+ &empty[..]
+ }
+ };
+
+ for member in members.iter() {
+ let member_name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+
+ self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
+ }
+
+ /// Writes special interface structures for an entry point. The special structures have
+ /// all the fields flattened into them and sorted by binding. They are only needed for
+ /// VS outputs and FS inputs, so that these interfaces match.
+ fn write_ep_interface(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ ep_name: &str,
+ ) -> Result<EntryPointInterface, Error> {
+ Ok(EntryPointInterface {
+ input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
+ Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
+ } else {
+ None
+ },
+ output: match func.result {
+ Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
+ Some(self.write_ep_output_struct(module, fr, stage, ep_name)?)
+ }
+ _ => None,
+ },
+ })
+ }
+
+ /// Write an entry point preface that initializes the arguments as specified in IR.
+ fn write_ep_arguments_initialization(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ ep_index: u16,
+ ) -> BackendResult {
+ let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
+ Some(ep_input) => ep_input,
+ None => return Ok(()),
+ };
+ let mut fake_iter = ep_input.members.iter();
+ for (arg_index, arg) in func.arguments.iter().enumerate() {
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(module, arg.ty)?;
+ let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
+ write!(self.out, " {}", arg_name)?;
+ match module.types[arg.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ self.write_array_size(module, base, size)?;
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ TypeInner::Struct { ref members, .. } => {
+ write!(self.out, " = {{ ")?;
+ for index in 0..members.len() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let fake_member = fake_iter.next().unwrap();
+ write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ }
+ }
+ assert!(fake_iter.next().is_none());
+ Ok(())
+ }
+
+ /// Helper method used to write global variables
+ /// # Notes
+ /// Always adds a newline
+ fn write_global(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::GlobalVariable>,
+ ) -> BackendResult {
+ let global = &module.global_variables[handle];
+ let inner = &module.types[global.ty].inner;
+
+ if let Some(ref binding) = global.binding {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ log::info!(
+ "Skipping global {:?} (name {:?}) for being inaccessible: {}",
+ handle,
+ global.name,
+ err,
+ );
+ return Ok(());
+ }
+ }
+
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
+ let register_ty = match global.space {
+ crate::AddressSpace::Function => unreachable!("Function address space"),
+ crate::AddressSpace::Private => {
+ write!(self.out, "static ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::WorkGroup => {
+ write!(self.out, "groupshared ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::Uniform => {
+ // constant buffer declarations are expected to be inlined, e.g.
+ // `cbuffer foo: register(b0) { field1: type1; }`
+ write!(self.out, "cbuffer")?;
+ "b"
+ }
+ crate::AddressSpace::Storage { access } => {
+ let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
+ ("RW", "u")
+ } else {
+ ("", "t")
+ };
+ write!(self.out, "{}ByteAddressBuffer", prefix)?;
+ register
+ }
+ crate::AddressSpace::Handle => {
+ let handle_ty = match *inner {
+ TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
+ _ => inner,
+ };
+
+ let register = match *handle_ty {
+ TypeInner::Sampler { .. } => "s",
+ // all storage textures are UAV, unconditionally
+ TypeInner::Image {
+ class: crate::ImageClass::Storage { .. },
+ ..
+ } => "u",
+ _ => "t",
+ };
+ self.write_type(module, global.ty)?;
+ register
+ }
+ crate::AddressSpace::PushConstant => {
+ // The type of the push constants will be wrapped in `ConstantBuffer`
+ write!(self.out, "ConstantBuffer<")?;
+ "b"
+ }
+ };
+
+ // If the global is a push constant write the type now because it will be a
+ // generic argument to `ConstantBuffer`
+ if global.space == crate::AddressSpace::PushConstant {
+ self.write_global_type(module, global.ty)?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ // Close the angled brackets for the generic argument
+ write!(self.out, ">")?;
+ }
+
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, " {}", name)?;
+
+ // Push constants need to be assigned a binding explicitly by the consumer
+ // since naga has no way to know the binding from the shader alone
+ if global.space == crate::AddressSpace::PushConstant {
+ let target = self
+ .options
+ .push_constants_target
+ .as_ref()
+ .expect("No bind target was defined for the push constants block");
+ write!(self.out, ": register(b{}", target.register)?;
+ if target.space != 0 {
+ write!(self.out, ", space{}", target.space)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ if let Some(ref binding) = global.binding {
+ // this was already resolved earlier when we started evaluating an entry point.
+ let bt = self.options.resolve_resource_binding(binding).unwrap();
+
+ // need to write the binding array size if the type was emitted with `write_type`
+ if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
+ if let Some(overridden_size) = bt.binding_array_size {
+ write!(self.out, "[{}]", overridden_size)?;
+ } else {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+
+ write!(self.out, " : register({}{}", register_ty, bt.register)?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ write!(self.out, ")")?;
+ } else {
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ if global.space == crate::AddressSpace::Private {
+ write!(self.out, " = ")?;
+ if let Some(init) = global.init {
+ self.write_constant(module, init)?;
+ } else {
+ self.write_default_init(module, global.ty)?;
+ }
+ }
+ }
+
+ if global.space == crate::AddressSpace::Uniform {
+ write!(self.out, " {{ ")?;
+
+ self.write_global_type(module, global.ty)?;
+
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ writeln!(self.out, "; }}")?;
+ } else {
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_global_constant(
+ &mut self,
+ module: &Module,
+ inner: &crate::ConstantInner,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ write!(self.out, "static const ")?;
+ match *inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => {
+ // Write type
+ let ty_str = match *value {
+ crate::ScalarValue::Sint(_) => "int",
+ crate::ScalarValue::Uint(_) => "uint",
+ crate::ScalarValue::Float(_) => "float",
+ crate::ScalarValue::Bool(_) => "bool",
+ };
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, "{} {} = ", ty_str, name)?;
+
+ // Second match required to avoid heap allocation by `format!()`
+ match *value {
+ crate::ScalarValue::Sint(value) => write!(self.out, "{}", value)?,
+ crate::ScalarValue::Uint(value) => write!(self.out, "{}", value)?,
+ crate::ScalarValue::Float(value) => {
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ write!(self.out, "{:?}", value)?
+ }
+ crate::ScalarValue::Bool(value) => write!(self.out, "{}", value)?,
+ };
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ self.write_type(module, ty)?;
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, " {} = ", name)?;
+ self.write_composite_constant(module, ty, components)?;
+ }
+ }
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ pub(super) fn write_array_size(
+ &mut self,
+ module: &Module,
+ base: Handle<crate::Type>,
+ size: crate::ArraySize,
+ ) -> BackendResult {
+ write!(self.out, "[")?;
+
+ // Write the array size
+ // Writes nothing if `ArraySize::Dynamic`
+ // Panics if `ArraySize::Constant` has a constant that isn't an sint or uint
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let size = module.constants[const_handle].to_array_length().unwrap();
+ write!(self.out, "{}", size)?;
+ }
+ crate::ArraySize::Dynamic => {}
+ }
+
+ write!(self.out, "]")?;
+
+ if let TypeInner::Array {
+ base: next_base,
+ size: next_size,
+ ..
+ } = module.types[base].inner
+ {
+ self.write_array_size(module, next_base, next_size)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ span: u32,
+ shader_stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ // Write struct name
+ let struct_name = &self.names[&NameKey::Type(handle)];
+ writeln!(self.out, "struct {} {{", struct_name)?;
+
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ if member.binding.is_none() && member.offset > last_offset {
+ // using int as padding should work as long as the backend
+ // doesn't support a type that's less than 4 bytes in size
+ // (Error::UnsupportedScalar catches this)
+ let padding = (member.offset - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
+ }
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset
+ + ty_inner
+ .try_size_hlsl(&module.types, &module.constants)
+ .unwrap();
+
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+
+ match module.types[member.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ // HLSL arrays are written as `type name[size]`
+
+ self.write_global_type(module, member.ty)?;
+
+ // Write `name`
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ // Write [size]
+ self.write_array_size(module, base, size)?;
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ TypeInner::Matrix {
+ rows,
+ columns,
+ width,
+ } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
+ let vec_ty = crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ };
+ let field_name_key = NameKey::StructMember(handle, index as u32);
+
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, "; ")?;
+ }
+ self.write_value_type(module, &vec_ty)?;
+ write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
+ }
+ }
+ _ => {
+ // Write modifier before type
+ if let Some(ref binding) = member.binding {
+ self.write_modifier(binding)?;
+ }
+
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
+ write!(self.out, "row_major ")?;
+ }
+
+ // Write the member type and name
+ self.write_type(module, member.ty)?;
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ }
+ }
+
+ if let Some(ref binding) = member.binding {
+ self.write_semantic(binding, shader_stage)?;
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
+ if members.last().unwrap().binding.is_none() && span > last_offset {
+ let padding = (span - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
+ }
+ }
+
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ /// Helper method used to write global/structs non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_global_type(
+ &mut self,
+ module: &Module,
+ ty: Handle<crate::Type>,
+ ) -> BackendResult {
+ let matrix_data = get_inner_matrix_data(module, ty);
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = matrix_data
+ {
+ write!(self.out, "__mat{}x2", columns as u8)?;
+ } else {
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if matrix_data.is_some() {
+ write!(self.out, "row_major ")?;
+ }
+
+ self.write_type(module, ty)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &module.types[ty].inner;
+ match *inner {
+ TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
+ // hlsl array has the size separated from the base type
+ TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
+ self.write_type(module, base)?
+ }
+ ref other => self.write_value_type(module, other)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ TypeInner::Scalar { kind, width } | TypeInner::Atomic { kind, width } => {
+ write!(self.out, "{}", kind.to_hlsl_str(width)?)?;
+ }
+ TypeInner::Vector { size, kind, width } => {
+ write!(
+ self.out,
+ "{}{}",
+ kind.to_hlsl_str(width)?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ // The IR supports only float matrix
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
+
+ // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
+ write!(
+ self.out,
+ "{}{}x{}",
+ crate::ScalarKind::Float.to_hlsl_str(width)?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ )?;
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ self.write_image_type(dim, arrayed, class)?;
+ }
+ TypeInner::Sampler { comparison } => {
+ let sampler = if comparison {
+ "SamplerComparisonState"
+ } else {
+ "SamplerState"
+ };
+ write!(self.out, "{}", sampler)?;
+ }
+ // HLSL arrays are written as `type name[size]`
+ // Current code is written arrays only as `[size]`
+ // Base `type` and `name` should be written outside
+ TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
+ self.write_array_size(module, base, size)?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_value_type {:?}",
+ inner
+ )))
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write functions
+ /// # Notes
+ /// Ends in a newline
+ fn write_function(
+ &mut self,
+ module: &Module,
+ name: &str,
+ func: &crate::Function,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
+
+ // Write modifier
+ if let Some(crate::FunctionResult {
+ binding:
+ Some(
+ ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position {
+ invariant: true,
+ }),
+ ),
+ ..
+ }) = func.result
+ {
+ self.write_modifier(binding)?;
+ }
+
+ // Write return type
+ if let Some(ref result) = func.result {
+ match func_ctx.ty {
+ back::FunctionType::Function(_) => {
+ self.write_type(module, result.ty)?;
+ }
+ back::FunctionType::EntryPoint(index) => {
+ if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
+ write!(self.out, "{}", ep_output.ty_name)?;
+ } else {
+ self.write_type(module, result.ty)?;
+ }
+ }
+ }
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ // Write function name
+ write!(self.out, " {}(", name)?;
+
+ // Write function arguments for non entry point functions
+ match func_ctx.ty {
+ back::FunctionType::Function(handle) => {
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // Write argument type
+ let arg_ty = match module.types[arg.ty].inner {
+ // pointers in function arguments are expected and resolve to `inout`
+ TypeInner::Pointer { base, .. } => {
+ //TODO: can we narrow this down to just `in` when possible?
+ write!(self.out, "inout ")?;
+ base
+ }
+ _ => arg.ty,
+ };
+ self.write_type(module, arg_ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::FunctionArgument(handle, index as u32)];
+
+ // Write argument name. Space is important.
+ write!(self.out, " {}", argument_name)?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
+ write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
+ } else {
+ let stage = module.entry_points[ep_index as usize].stage;
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, arg.ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
+
+ write!(self.out, " {}", argument_name)?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ if let Some(ref binding) = arg.binding {
+ self.write_semantic(binding, Some((stage, Io::Input)))?;
+ }
+ }
+ }
+ }
+ }
+ // Ends of arguments
+ write!(self.out, ")")?;
+
+ // Write semantic if it present
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ let stage = module.entry_points[index as usize].stage;
+ if let Some(crate::FunctionResult {
+ binding: Some(ref binding),
+ ..
+ }) = func.result
+ {
+ self.write_semantic(binding, Some((stage, Io::Output)))?;
+ }
+ }
+
+ // Function body start
+ writeln!(self.out)?;
+ writeln!(self.out, "{{")?;
+
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ self.write_ep_arguments_initialization(module, func, index)?;
+ }
+
+ // Write function local variables
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability)
+ write!(self.out, "{}", back::INDENT)?;
+
+ // Write the local name
+ // The leading space is important
+ self.write_type(module, local.ty)?;
+ write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ write!(self.out, " = ")?;
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ // Put the equal signal only if there's a initializer
+ // The leading and trailing spaces aren't needed but help with readability
+
+ // Write the constant
+ // `write_constant` adds no trailing or leading space/newline
+ self.write_constant(module, init)?;
+ } else {
+ // Zero initialize local variables
+ self.write_default_init(module, local.ty)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ if !func.local_variables.is_empty() {
+ writeln!(self.out)?;
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // The indentation should always be 1 when writing the function body
+ self.write_stmt(module, sta, func_ctx, back::Level(1))?;
+ }
+
+ writeln!(self.out, "}}")?;
+
+ self.named_expressions.clear();
+
+ Ok(())
+ }
+
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ module: &Module,
+ stmt: &crate::Statement,
+ func_ctx: &back::FunctionCtx<'_>,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::Statement;
+
+ match *stmt {
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let info = &func_ctx.info[handle];
+ let ptr_class = info.ty.inner_with(&module.types).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ // HLSL can't save a pointer-valued expression in a variable,
+ // but we shouldn't ever need to: they should never be named expressions,
+ // and none of the expression types flagged by bake_ref_count can be pointer-valued.
+ None
+ } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else if info.ref_count == 0 {
+ Some(self.namer.call(""))
+ } else {
+ let min_ref_count = func_ctx.expressions[handle].bake_ref_count();
+ if min_ref_count <= info.ref_count {
+ Some(format!("_expr{}", handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{}", level)?;
+ self.write_named_expr(module, handle, name, func_ctx)?;
+ }
+ }
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Block(ref block) => {
+ write!(self.out, "{}", level)?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, level.next())?
+ }
+ writeln!(self.out, "{}}}", level)?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{}", level)?;
+ write!(self.out, "if (")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ let l2 = level.next();
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{}}} else {{", level)?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ }
+
+ writeln!(self.out, "{}}}", level)?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Kill => writeln!(self.out, "{}discard;", level)?,
+ Statement::Return { value: None } => {
+ writeln!(self.out, "{}return;", level)?;
+ }
+ Statement::Return { value: Some(expr) } => {
+ let base_ty_res = &func_ctx.info[expr].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ if let TypeInner::Pointer { base, space: _ } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ if let TypeInner::Struct { .. } = *resolved {
+ // We can safely unwrap here, since we now we working with struct
+ let ty = base_ty_res.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(ty)];
+ let variable_name = self.namer.call(&struct_name.to_lowercase());
+ write!(
+ self.out,
+ "{}const {} {} = ",
+ level, struct_name, variable_name,
+ )?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // for entry point returns, we may need to reshuffle the outputs into a different struct
+ let ep_output = match func_ctx.ty {
+ back::FunctionType::Function(_) => None,
+ back::FunctionType::EntryPoint(index) => {
+ self.entry_point_io[index as usize].output.as_ref()
+ }
+ };
+ let final_name = match ep_output {
+ Some(ep_output) => {
+ let final_name = self.namer.call(&variable_name);
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ level, ep_output.ty_name, final_name,
+ )?;
+ for (index, m) in ep_output.members.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
+ write!(self.out, "{}.{}", variable_name, member_name)?;
+ }
+ writeln!(self.out, " }};")?;
+ final_name
+ }
+ None => variable_name,
+ };
+ writeln!(self.out, "{}return {};", level, final_name)?;
+ } else {
+ write!(self.out, "{}return ", level)?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ Statement::Store { pointer, value } => {
+ let ty_inner = func_ctx.info[pointer].ty.inner_with(&module.types);
+ if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ self.write_storage_store(
+ module,
+ var_handle,
+ StoreValue::Expression(value),
+ func_ctx,
+ level,
+ )?;
+ } else {
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
+ // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
+ struct MatrixAccess {
+ base: Handle<crate::Expression>,
+ index: u32,
+ }
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let get_members = |expr: Handle<crate::Expression>| {
+ let base_ty_res = &func_ctx.info[expr].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+ match *resolved {
+ TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ TypeInner::Struct { ref members, .. } => Some(members),
+ _ => None,
+ },
+ _ => None,
+ }
+ };
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.info[current_expr].ty.inner_with(&module.types);
+
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::Pointer { base: ty, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) if matches!(
+ module.types[ty].inner,
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ }
+ ) && get_members(base)
+ .map(|members| members[index as usize].binding.is_none())
+ == Some(true) =>
+ {
+ matrix = Some(MatrixAccess { base, index });
+ break;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ vector = Some(Index::Static(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => break,
+ }
+ }
+
+ write!(self.out, "{}", level)?;
+
+ if let Some(MatrixAccess { index, base }) = matrix {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+ let ty = match *resolved {
+ TypeInner::Pointer { base, .. } => base,
+ _ => base_ty_res.handle().unwrap(),
+ };
+
+ if let Some(Index::Static(vec_index)) = vector {
+ self.write_expr(module, base, func_ctx)?;
+ write!(
+ self.out,
+ ".{}_{}",
+ &self.names[&NameKey::StructMember(ty, index)],
+ vec_index
+ )?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, "[")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{}", index)?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ write!(self.out, "]")?;
+ }
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ } else {
+ let access = WrappedStructMatrixAccess { ty, index };
+ match (&vector, &scalar) {
+ (&Some(_), &Some(_)) => {
+ self.write_wrapped_struct_matrix_set_scalar_function_name(
+ access,
+ )?;
+ }
+ (&Some(_), &None) => {
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+ }
+ (&None, _) => {
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+ }
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ if let Some(Index::Expression(vec_index)) = vector {
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{}", index)?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
+ } else {
+ // We handle `Store`s to __matCx2 column vectors and scalar elements via
+ // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
+ struct MatrixData {
+ columns: crate::VectorSize,
+ base: Handle<crate::Expression>,
+ }
+
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.info[current_expr].ty.inner_with(&module.types);
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(index);
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => {
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module,
+ current_expr,
+ func_ctx,
+ true,
+ ) {
+ matrix = Some(MatrixData {
+ columns,
+ base: current_expr,
+ });
+ }
+
+ break;
+ }
+ }
+ }
+
+ if let (Some(MatrixData { columns, base }), Some(vec_index)) =
+ (matrix, vector)
+ {
+ if scalar.is_some() {
+ write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
+ } else {
+ write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
+ }
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{}", index)?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ writeln!(self.out, ");")?;
+ } else {
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, " = ")?;
+
+ // We cast the RHS of this store in cases where the LHS
+ // is a struct member with type:
+ // - matCx2 or
+ // - a (possibly nested) array of matCx2's
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ ) {
+ let mut resolved =
+ func_ctx.info[pointer].ty.inner_with(&module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "(__mat{}x2", columns as u8)?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ }
+ }
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let l2 = level.next();
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{}bool {} = true;", level, gate_name)?;
+ writeln!(self.out, "{}while(true) {{", level)?;
+ writeln!(self.out, "{}if (!{}) {{", l2, gate_name)?;
+ let l3 = l2.next();
+ for sta in continuing.iter() {
+ self.write_stmt(module, sta, func_ctx, l3)?;
+ }
+ if let Some(condition) = break_if {
+ write!(self.out, "{}if (", l3)?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", l3.next())?;
+ writeln!(self.out, "{}}}", l3)?;
+ }
+ writeln!(self.out, "{}}}", l2)?;
+ writeln!(self.out, "{}{} = false;", l2, gate_name)?;
+ } else {
+ writeln!(self.out, "{}while(true) {{", level)?;
+ }
+
+ for sta in body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ writeln!(self.out, "{}}}", level)?
+ }
+ Statement::Break => writeln!(self.out, "{}break;", level)?,
+ Statement::Continue => writeln!(self.out, "{}continue;", level)?,
+ Statement::Barrier(barrier) => {
+ if barrier.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{}DeviceMemoryBarrierWithGroupSync();", level)?;
+ }
+
+ if barrier.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{}GroupMemoryBarrierWithGroupSync();", level)?;
+ }
+ }
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{}", level)?;
+ self.write_expr(module, image, func_ctx)?;
+
+ write!(self.out, "[")?;
+ if let Some(index) = array_index {
+ // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
+ write!(self.out, "int3(")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(module, coordinate, func_ctx)?;
+ }
+ write!(self.out, "]")?;
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ }
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{}", level)?;
+ if let Some(expr) = result {
+ write!(self.out, "const ")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ let expr_ty = &func_ctx.info[expr].ty;
+ match *expr_ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+ write!(self.out, " {} = ", name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let func_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{}(", func_name)?;
+ for (index, argument) in arguments.iter().enumerate() {
+ self.write_expr(module, *argument, func_ctx)?;
+ // Only write a comma if isn't the last element
+ if index != arguments.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{}", level)?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ match func_ctx.info[result].ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+
+ let fun_str = fun.to_hlsl_suffix();
+ write!(
+ self.out,
+ " {}; {}.Interlocked{}(",
+ res_name, var_name, fun_str
+ )?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", ")?;
+ // handle the special cases
+ match *fun {
+ crate::AtomicFunction::Subtract => {
+ // we just wrote `InterlockedAdd`, so negate the argument
+ write!(self.out, "-")?;
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
+ }
+ _ => {}
+ }
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ", {});", res_name)?;
+ self.temp_access_chain = chain;
+ self.named_expressions.insert(result, res_name);
+ }
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{}", level)?;
+ write!(self.out, "switch(")?;
+ self.write_expr(module, selector, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+ let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ } => "u",
+ _ => "",
+ };
+
+ // Write all cases
+ let indent_level_1 = level.next();
+ let indent_level_2 = indent_level_1.next();
+
+ for (i, case) in cases.iter().enumerate() {
+ match case.value {
+ crate::SwitchValue::Integer(value) => writeln!(
+ self.out,
+ "{}case {}{}: {{",
+ indent_level_1, value, type_postfix
+ )?,
+ crate::SwitchValue::Default => {
+ writeln!(self.out, "{}default: {{", indent_level_1)?
+ }
+ }
+
+ // FXC doesn't support fallthrough so we duplicate the body of the following case blocks
+ if case.fall_through {
+ let curr_len = i + 1;
+ let end_case_idx = curr_len
+ + cases
+ .iter()
+ .skip(curr_len)
+ .position(|case| !case.fall_through)
+ .unwrap();
+ let indent_level_3 = indent_level_2.next();
+ for case in &cases[i..=end_case_idx] {
+ writeln!(self.out, "{}{{", indent_level_2)?;
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_3)?;
+ }
+ writeln!(self.out, "{}}}", indent_level_2)?;
+ }
+
+ let last_case = &cases[end_case_idx];
+ if last_case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{}break;", indent_level_2)?;
+ }
+ } else {
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_2)?;
+ }
+ if case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{}break;", indent_level_2)?;
+ }
+ }
+
+ writeln!(self.out, "{}}}", indent_level_1)?;
+ }
+
+ writeln!(self.out, "{}}}", level)?
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write expressions
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ pub(super) fn write_expr(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ // Handle the special semantics for base vertex/instance
+ let ff_input = if self.options.special_constants_binding.is_some() {
+ func_ctx.is_fixed_function_input(expr, module)
+ } else {
+ None
+ };
+ let closing_bracket = match ff_input {
+ Some(crate::BuiltIn::VertexIndex) => {
+ write!(self.out, "({}.{} + ", SPECIAL_CBUF_VAR, SPECIAL_BASE_VERTEX)?;
+ ")"
+ }
+ Some(crate::BuiltIn::InstanceIndex) => {
+ write!(
+ self.out,
+ "({}.{} + ",
+ SPECIAL_CBUF_VAR, SPECIAL_BASE_INSTANCE,
+ )?;
+ ")"
+ }
+ Some(crate::BuiltIn::NumWorkGroups) => {
+ //Note: despite their names (`BASE_VERTEX` and `BASE_INSTANCE`),
+ // in compute shaders the special constants contain the number
+ // of workgroups, which we are using here.
+ write!(
+ self.out,
+ "uint3({}.{}, {}.{}, {}.{})",
+ SPECIAL_CBUF_VAR,
+ SPECIAL_BASE_VERTEX,
+ SPECIAL_CBUF_VAR,
+ SPECIAL_BASE_INSTANCE,
+ SPECIAL_CBUF_VAR,
+ SPECIAL_OTHER,
+ )?;
+ return Ok(());
+ }
+ _ => "",
+ };
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{}{}", name, closing_bracket)?;
+ return Ok(());
+ }
+
+ let expression = &func_ctx.expressions[expr];
+
+ match *expression {
+ Expression::Constant(constant) => self.write_constant(module, constant)?,
+ Expression::Compose { ty, ref components } => {
+ match module.types[ty].inner {
+ TypeInner::Struct { .. } | TypeInner::Array { .. } => {
+ self.write_wrapped_constructor_function_name(
+ module,
+ WrappedConstructor { ty },
+ )?;
+ }
+ _ => {
+ self.write_type(module, ty)?;
+ }
+ };
+
+ write!(self.out, "(")?;
+
+ for (index, &component) in components.iter().enumerate() {
+ if index != 0 {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ self.write_expr(module, component, func_ctx)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ // All of the multiplication can be expressed as `mul`,
+ // except vector * vector, which needs to use the "*" operator.
+ Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left,
+ right,
+ } if func_ctx.info[left].ty.inner_with(&module.types).is_matrix()
+ || func_ctx.info[right]
+ .ty
+ .inner_with(&module.types)
+ .is_matrix() =>
+ {
+ // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
+ write!(self.out, "mul(")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) != sign(right) return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ // While HLSL supports float operands with the % operator it is only
+ // defined in cases where both sides are either positive or negative.
+ Expression::Binary {
+ op: crate::BinaryOperator::Modulo,
+ left,
+ right,
+ } if func_ctx.info[left]
+ .ty
+ .inner_with(&module.types)
+ .scalar_kind()
+ == Some(crate::ScalarKind::Float) =>
+ {
+ write!(self.out, "fmod(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, " {} ", crate::back::binary_operation_str(op))?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Access { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) = func_ctx.info[expr]
+ .ty
+ .inner_with(&module.types)
+ .pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ // We use the function __get_col_of_matCx2 here in cases
+ // where `base`s type resolves to a matCx2 and is part of a
+ // struct member with type of (possibly nested) array of matCx2's.
+ //
+ // Note that this only works for `Load`s and we handle
+ // `Store`s differently in `Statement::Store`.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+
+ let base_ty_res = &func_ctx.info[base].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+
+ let non_uniform_qualifier = match *resolved {
+ TypeInner::BindingArray { .. } => {
+ let uniformity = &func_ctx.info[index].uniformity;
+
+ uniformity.non_uniform_result.is_some()
+ }
+ _ => false,
+ };
+
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "[")?;
+ if non_uniform_qualifier {
+ write!(self.out, "NonUniformResourceIndex(")?;
+ }
+ self.write_expr(module, index, func_ctx)?;
+ if non_uniform_qualifier {
+ write!(self.out, ")")?;
+ }
+ write!(self.out, "]")?;
+ }
+ }
+ Expression::AccessIndex { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) = func_ctx.info[expr]
+ .ty
+ .inner_with(&module.types)
+ .pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ fn write_access<W: fmt::Write>(
+ writer: &mut super::Writer<'_, W>,
+ resolved: &TypeInner,
+ base_ty_handle: Option<Handle<crate::Type>>,
+ index: u32,
+ ) -> BackendResult {
+ match *resolved {
+ // We specifcally lift the ValuePointer to this case. While `[0]` is valid
+ // HLSL for any vector behind a value pointer, FXC completely miscompiles
+ // it and generates completely nonsensical DXBC.
+ //
+ // See https://github.com/gfx-rs/naga/issues/2095 for more details.
+ TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
+ // Write vector access as a swizzle
+ write!(writer.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::BindingArray { .. } => write!(writer.out, "[{}]", index)?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ writer.out,
+ ".{}",
+ &writer.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => {
+ return Err(Error::Custom(format!("Cannot index {:?}", other)))
+ }
+ }
+ Ok(())
+ }
+
+ // We write the matrix column access in a special way since
+ // the type of `base` is our special __matCx2 struct.
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "._{}", index)?;
+ return Ok(());
+ }
+
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix reconstruction here for Loads.
+ // Stores are handled directly by `Statement::Store`.
+ if let TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ self.write_wrapped_struct_matrix_get_function_name(
+ WrappedStructMatrixAccess { ty, index },
+ )?;
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+ _ => {}
+ }
+ }
+
+ self.write_expr(module, base, func_ctx)?;
+ write_access(self, resolved, base_ty_handle, index)?;
+ }
+ }
+ Expression::FunctionArgument(pos) => {
+ let key = match func_ctx.ty {
+ back::FunctionType::Function(handle) => NameKey::FunctionArgument(handle, pos),
+ back::FunctionType::EntryPoint(index) => {
+ NameKey::EntryPointArgument(index, pos)
+ }
+ };
+ let name = &self.names[&key];
+ write!(self.out, "{}", name)?;
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ use crate::SampleLevel as Sl;
+ const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
+
+ let (base_str, component_str) = match gather {
+ Some(component) => ("Gather", COMPONENTS[component as usize]),
+ None => ("Sample", ""),
+ };
+ let cmp_str = match depth_ref {
+ Some(_) => "Cmp",
+ None => "",
+ };
+ let level_str = match level {
+ Sl::Zero if gather.is_none() => "LevelZero",
+ Sl::Auto | Sl::Zero => "",
+ Sl::Exact(_) => "Level",
+ Sl::Bias(_) => "Bias",
+ Sl::Gradient { .. } => "Grad",
+ };
+
+ self.write_expr(module, image, func_ctx)?;
+ write!(
+ self.out,
+ ".{}{}{}{}(",
+ base_str, cmp_str, component_str, level_str
+ )?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_texture_coordinates(
+ "float",
+ coordinate,
+ array_index,
+ None,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ match level {
+ Sl::Auto | Sl::Zero => {}
+ Sl::Exact(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Gradient { x, y } => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, x, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, y, func_ctx)?;
+ }
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_constant(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageQuery { image, query } => {
+ // use wrapped image query function
+ if let TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } = *func_ctx.info[image].ty.inner_with(&module.types)
+ {
+ let wrapped_image_query = WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ };
+
+ self.write_wrapped_image_query_function_name(wrapped_image_query)?;
+ write!(self.out, "(")?;
+ // Image always first param
+ self.write_expr(module, image, func_ctx)?;
+ if let crate::ImageQuery::Size { level: Some(level) } = query {
+ write!(self.out, ", ")?;
+ self.write_expr(module, level, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ".Load(")?;
+
+ self.write_texture_coordinates(
+ "int",
+ coordinate,
+ array_index,
+ level,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(sample) = sample {
+ write!(self.out, ", ")?;
+ self.write_expr(module, sample, func_ctx)?;
+ }
+
+ // close bracket for Load function
+ write!(self.out, ")")?;
+
+ // return x component if return type is scalar
+ if let TypeInner::Scalar { .. } = *func_ctx.info[expr].ty.inner_with(&module.types)
+ {
+ write!(self.out, ".x")?;
+ }
+ }
+ Expression::GlobalVariable(handle) => match module.global_variables[handle].space {
+ crate::AddressSpace::Storage { .. } => {}
+ _ => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}", name)?;
+ }
+ },
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
+ }
+ Expression::Load { pointer } => {
+ match func_ctx.info[pointer]
+ .ty
+ .inner_with(&module.types)
+ .pointer_space()
+ {
+ Some(crate::AddressSpace::Storage { .. }) => {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ let result_ty = func_ctx.info[expr].ty.clone();
+ self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
+ }
+ _ => {
+ let mut close_paren = false;
+
+ // We cast the value loaded to a native HLSL floatCx2
+ // in cases where it is of type:
+ // - __matCx2 or
+ // - a (possibly nested) array of __matCx2's
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ )
+ .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
+ {
+ let mut resolved = func_ctx.info[pointer].ty.inner_with(&module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "((")?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_type(module, base)?;
+ self.write_array_size(module, base, size)?;
+ } else {
+ self.write_value_type(module, resolved)?;
+ }
+ write!(self.out, ")")?;
+ close_paren = true;
+ }
+
+ self.write_expr(module, pointer, func_ctx)?;
+
+ if close_paren {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ }
+ Expression::Unary { op, expr } => {
+ use crate::{ScalarKind as Sk, UnaryOperator as Uo};
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
+ let op_str = match op {
+ Uo::Negate => "-",
+ Uo::Not => match func_ctx.info[expr]
+ .ty
+ .inner_with(&module.types)
+ .scalar_kind()
+ {
+ Some(Sk::Sint) | Some(Sk::Uint) => "~",
+ Some(Sk::Bool) => "!",
+ ref other => {
+ return Err(Error::Custom(format!(
+ "Cannot apply not to type {:?}",
+ other
+ )))
+ }
+ },
+ };
+ write!(self.out, "{}", op_str)?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let inner = func_ctx.info[expr].ty.inner_with(&module.types);
+ match convert {
+ Some(dst_width) => {
+ match *inner {
+ TypeInner::Vector { size, .. } => {
+ write!(
+ self.out,
+ "{}{}(",
+ kind.to_hlsl_str(dst_width)?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Scalar { .. } => {
+ write!(self.out, "{}(", kind.to_hlsl_str(dst_width)?,)?;
+ }
+ TypeInner::Matrix { columns, rows, .. } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ kind.to_hlsl_str(dst_width)?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::as {:?}",
+ inner
+ )));
+ }
+ };
+ }
+ None => {
+ write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
+ }
+ }
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ enum Function {
+ Asincosh { is_sin: bool },
+ Atanh,
+ Unpack2x16float,
+ Regular(&'static str),
+ MissingIntOverload(&'static str),
+ }
+
+ let fun = match fun {
+ // comparison
+ Mf::Abs => Function::Regular("abs"),
+ Mf::Min => Function::Regular("min"),
+ Mf::Max => Function::Regular("max"),
+ Mf::Clamp => Function::Regular("clamp"),
+ Mf::Saturate => Function::Regular("saturate"),
+ // trigonometry
+ Mf::Cos => Function::Regular("cos"),
+ Mf::Cosh => Function::Regular("cosh"),
+ Mf::Sin => Function::Regular("sin"),
+ Mf::Sinh => Function::Regular("sinh"),
+ Mf::Tan => Function::Regular("tan"),
+ Mf::Tanh => Function::Regular("tanh"),
+ Mf::Acos => Function::Regular("acos"),
+ Mf::Asin => Function::Regular("asin"),
+ Mf::Atan => Function::Regular("atan"),
+ Mf::Atan2 => Function::Regular("atan2"),
+ Mf::Asinh => Function::Asincosh { is_sin: true },
+ Mf::Acosh => Function::Asincosh { is_sin: false },
+ Mf::Atanh => Function::Atanh,
+ Mf::Radians => Function::Regular("radians"),
+ Mf::Degrees => Function::Regular("degrees"),
+ // decomposition
+ Mf::Ceil => Function::Regular("ceil"),
+ Mf::Floor => Function::Regular("floor"),
+ Mf::Round => Function::Regular("round"),
+ Mf::Fract => Function::Regular("frac"),
+ Mf::Trunc => Function::Regular("trunc"),
+ Mf::Modf => Function::Regular("modf"),
+ Mf::Frexp => Function::Regular("frexp"),
+ Mf::Ldexp => Function::Regular("ldexp"),
+ // exponent
+ Mf::Exp => Function::Regular("exp"),
+ Mf::Exp2 => Function::Regular("exp2"),
+ Mf::Log => Function::Regular("log"),
+ Mf::Log2 => Function::Regular("log2"),
+ Mf::Pow => Function::Regular("pow"),
+ // geometry
+ Mf::Dot => Function::Regular("dot"),
+ //Mf::Outer => ,
+ Mf::Cross => Function::Regular("cross"),
+ Mf::Distance => Function::Regular("distance"),
+ Mf::Length => Function::Regular("length"),
+ Mf::Normalize => Function::Regular("normalize"),
+ Mf::FaceForward => Function::Regular("faceforward"),
+ Mf::Reflect => Function::Regular("reflect"),
+ Mf::Refract => Function::Regular("refract"),
+ // computational
+ Mf::Sign => Function::Regular("sign"),
+ Mf::Fma => Function::Regular("mad"),
+ Mf::Mix => Function::Regular("lerp"),
+ Mf::Step => Function::Regular("step"),
+ Mf::SmoothStep => Function::Regular("smoothstep"),
+ Mf::Sqrt => Function::Regular("sqrt"),
+ Mf::InverseSqrt => Function::Regular("rsqrt"),
+ //Mf::Inverse =>,
+ Mf::Transpose => Function::Regular("transpose"),
+ Mf::Determinant => Function::Regular("determinant"),
+ // bits
+ Mf::CountOneBits => Function::MissingIntOverload("countbits"),
+ Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
+ Mf::FindLsb => Function::Regular("firstbitlow"),
+ Mf::FindMsb => Function::Regular("firstbithigh"),
+ Mf::Unpack2x16float => Function::Unpack2x16float,
+ _ => return Err(Error::Unimplemented(format!("write_expr_math {:?}", fun))),
+ };
+
+ match fun {
+ Function::Asincosh { is_sin } => {
+ write!(self.out, "log(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " + sqrt(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " * ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ match is_sin {
+ true => write!(self.out, " + 1.0))")?,
+ false => write!(self.out, " - 1.0))")?,
+ }
+ }
+ Function::Atanh => {
+ write!(self.out, "0.5 * log((1.0 + ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") / (1.0 - ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ }
+ Function::Unpack2x16float => {
+ write!(self.out, "float2(f16tof32(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "), f16tof32((")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 16))")?;
+ }
+ Function::Regular(fun_name) => {
+ write!(self.out, "{}(", fun_name)?;
+ self.write_expr(module, arg, func_ctx)?;
+ if let Some(arg) = arg1 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg2 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg3 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ write!(self.out, ")")?
+ }
+ Function::MissingIntOverload(fun_name) => {
+ let scalar_kind = &func_ctx.info[arg]
+ .ty
+ .inner_with(&module.types)
+ .scalar_kind();
+ if let Some(ScalarKind::Sint) = *scalar_kind {
+ write!(self.out, "asint({}(asuint(", fun_name)?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ } else {
+ write!(self.out, "{}(", fun_name)?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ }
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(module, vector, func_ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ Expression::ArrayLength(expr) => {
+ let var_handle = match func_ctx.expressions[expr] {
+ Expression::AccessIndex { base, index: _ } => {
+ match func_ctx.expressions[base] {
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ }
+ }
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ };
+
+ let var = &module.global_variables[var_handle];
+ let (offset, stride) = match module.types[var.ty].inner {
+ TypeInner::Array { stride, .. } => (0, stride),
+ TypeInner::Struct { ref members, .. } => {
+ let last = members.last().unwrap();
+ let stride = match module.types[last.ty].inner {
+ TypeInner::Array { stride, .. } => stride,
+ _ => unreachable!(),
+ };
+ (last.offset, stride)
+ }
+ _ => unreachable!(),
+ };
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wrapped_array_length = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ write!(self.out, "((")?;
+ self.write_wrapped_array_length_function_name(wrapped_array_length)?;
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "({}) - {}) / {})", var_name, offset, stride)?
+ }
+ Expression::Derivative { axis, expr } => {
+ use crate::DerivativeAxis as Da;
+
+ let fun_str = match axis {
+ Da::X => "ddx",
+ Da::Y => "ddy",
+ Da::Width => "fwidth",
+ };
+ write!(self.out, "{}(", fun_str)?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_str = match fun {
+ Rf::All => "all",
+ Rf::Any => "any",
+ Rf::IsNan => "isnan",
+ Rf::IsInf => "isinf",
+ Rf::IsFinite => "isfinite",
+ Rf::IsNormal => "isnormal",
+ };
+ write!(self.out, "{}(", fun_str)?;
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Splat { size, value } => {
+ // hlsl is not supported one value constructor
+ // if we write, for example, int4(0), dxc returns error:
+ // error: too few elements in vector initialization (expected 4 elements, have 1)
+ let number_of_components = match size {
+ crate::VectorSize::Bi => "xx",
+ crate::VectorSize::Tri => "xxx",
+ crate::VectorSize::Quad => "xxxx",
+ };
+ write!(self.out, "(")?;
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, ").{}", number_of_components)?
+ }
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, condition, func_ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(module, accept, func_ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(module, reject, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ // Nothing to do here, since call expression already cached
+ Expression::CallResult(_) | Expression::AtomicResult { .. } => {}
+ }
+
+ if !closing_bracket.is_empty() {
+ write!(self.out, "{}", closing_bracket)?;
+ }
+ Ok(())
+ }
+
+ /// Helper method used to write constants
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ fn write_constant(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ let constant = &module.constants[handle];
+ match constant.inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => {
+ if constant.name.is_some() {
+ write!(self.out, "{}", &self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_scalar_value(*value)?;
+ }
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ self.write_composite_constant(module, ty, components)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_composite_constant(
+ &mut self,
+ module: &Module,
+ ty: Handle<crate::Type>,
+ components: &[Handle<crate::Constant>],
+ ) -> BackendResult {
+ match module.types[ty].inner {
+ TypeInner::Struct { .. } | TypeInner::Array { .. } => {
+ self.write_wrapped_constructor_function_name(module, WrappedConstructor { ty })?;
+ }
+ _ => {
+ self.write_type(module, ty)?;
+ }
+ };
+ write!(self.out, "(")?;
+ for (index, constant) in components.iter().enumerate() {
+ self.write_constant(module, *constant)?;
+ // Only write a comma if isn't the last element
+ if index != components.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write [`ScalarValue`](crate::ScalarValue)
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_scalar_value(&mut self, value: crate::ScalarValue) -> BackendResult {
+ use crate::ScalarValue as Sv;
+
+ match value {
+ Sv::Sint(value) => write!(self.out, "{}", value)?,
+ Sv::Uint(value) => write!(self.out, "{}u", value)?,
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ Sv::Float(value) => write!(self.out, "{:?}", value)?,
+ Sv::Bool(value) => write!(self.out, "{}", value)?,
+ }
+
+ Ok(())
+ }
+
+ fn write_named_expr(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Expression>,
+ name: String,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ match ctx.info[handle].ty {
+ proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
+ TypeInner::Struct { .. } => {
+ let ty_name = &self.names[&NameKey::Type(ty_handle)];
+ write!(self.out, "{}", ty_name)?;
+ }
+ _ => {
+ self.write_type(module, ty_handle)?;
+ }
+ },
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(module, inner)?;
+ }
+ }
+
+ let base_ty_res = &ctx.info[handle].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+
+ write!(self.out, " {}", name)?;
+ // If rhs is a array type, we should write array size
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_expr(module, handle, ctx)?;
+ writeln!(self.out, ";")?;
+ self.named_expressions.insert(handle, name);
+
+ Ok(())
+ }
+
+ /// Helper function that write default zero initialization
+ fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ match module.types[ty].inner {
+ TypeInner::Array {
+ size: crate::ArraySize::Constant(const_handle),
+ base,
+ ..
+ } => {
+ write!(self.out, "{{")?;
+ let count = module.constants[const_handle].to_array_length().unwrap();
+ for i in 0..count {
+ if i != 0 {
+ write!(self.out, ",")?;
+ }
+ self.write_default_init(module, base)?;
+ }
+ write!(self.out, "}}")?;
+ }
+ _ => {
+ write!(self.out, "(")?;
+ self.write_type(module, ty)?;
+ write!(self.out, ")0")?;
+ }
+ }
+ Ok(())
+ }
+}
+
+pub(super) struct MatrixType {
+ pub(super) columns: crate::VectorSize,
+ pub(super) rows: crate::VectorSize,
+ pub(super) width: crate::Bytes,
+}
+
+pub(super) fn get_inner_matrix_data(
+ module: &Module,
+ handle: Handle<crate::Type>,
+) -> Option<MatrixType> {
+ match module.types[handle].inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => Some(MatrixType {
+ columns,
+ rows,
+ width,
+ }),
+ TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
+ _ => None,
+ }
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
+/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends at an expression with resolved type of [`TypeInner::Struct`]
+pub(super) fn get_inner_matrix_of_struct_array_member(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ direct: bool,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.info[current_base].ty.inner_with(&module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ TypeInner::Struct { .. } => {
+ if let Some(array_base) = array_base {
+ if direct {
+ return mat_data;
+ } else {
+ return get_inner_matrix_data(module, array_base);
+ }
+ }
+
+ break;
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ _ => break,
+ };
+ }
+ None
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
+/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
+fn get_inner_matrix_of_global_uniform(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.info[current_base].ty.inner_with(&module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ crate::Expression::GlobalVariable(handle)
+ if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
+ {
+ return mat_data.or_else(|| {
+ array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
+ })
+ }
+ _ => break,
+ };
+ }
+ None
+}
diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs
new file mode 100644
index 0000000000..d8e016c008
--- /dev/null
+++ b/third_party/rust/naga/src/back/mod.rs
@@ -0,0 +1,209 @@
+/*!
+Backend functions that export shader [`Module`](super::Module)s into binary and text formats.
+*/
+#![allow(dead_code)] // can be dead if none of the enabled backends need it
+
+#[cfg(feature = "dot-out")]
+pub mod dot;
+#[cfg(feature = "glsl-out")]
+pub mod glsl;
+#[cfg(feature = "hlsl-out")]
+pub mod hlsl;
+#[cfg(feature = "msl-out")]
+pub mod msl;
+#[cfg(feature = "spv-out")]
+pub mod spv;
+#[cfg(feature = "wgsl-out")]
+pub mod wgsl;
+
+const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
+const INDENT: &str = " ";
+const BAKE_PREFIX: &str = "_e";
+
+type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>;
+
+#[derive(Clone, Copy)]
+struct Level(usize);
+
+impl Level {
+ const fn next(&self) -> Self {
+ Level(self.0 + 1)
+ }
+}
+
+impl std::fmt::Display for Level {
+ fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+ (0..self.0).try_for_each(|_| formatter.write_str(INDENT))
+ }
+}
+
+/// Stores the current function type (either a regular function or an entry point)
+///
+/// Also stores data needed to identify it (handle for a regular function or index for an entry point)
+enum FunctionType {
+ /// A regular function and it's handle
+ Function(crate::Handle<crate::Function>),
+ /// A entry point and it's index
+ EntryPoint(crate::proc::EntryPointIndex),
+}
+
+/// Helper structure that stores data needed when writing the function
+struct FunctionCtx<'a> {
+ /// The current function being written
+ ty: FunctionType,
+ /// Analysis about the function
+ info: &'a crate::valid::FunctionInfo,
+ /// The expression arena of the current function being written
+ expressions: &'a crate::Arena<crate::Expression>,
+ /// Map of expressions that have associated variable names
+ named_expressions: &'a crate::NamedExpressions,
+}
+
+impl FunctionCtx<'_> {
+ /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a local in the current function
+ const fn name_key(&self, local: crate::Handle<crate::LocalVariable>) -> crate::proc::NameKey {
+ match self.ty {
+ FunctionType::Function(handle) => crate::proc::NameKey::FunctionLocal(handle, local),
+ FunctionType::EntryPoint(idx) => crate::proc::NameKey::EntryPointLocal(idx, local),
+ }
+ }
+
+ /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a function argument.
+ ///
+ /// # Panics
+ /// - If the function arguments are less or equal to `arg`
+ const fn argument_key(&self, arg: u32) -> crate::proc::NameKey {
+ match self.ty {
+ FunctionType::Function(handle) => crate::proc::NameKey::FunctionArgument(handle, arg),
+ FunctionType::EntryPoint(ep_index) => {
+ crate::proc::NameKey::EntryPointArgument(ep_index, arg)
+ }
+ }
+ }
+
+ // Returns true if the given expression points to a fixed-function pipeline input.
+ fn is_fixed_function_input(
+ &self,
+ mut expression: crate::Handle<crate::Expression>,
+ module: &crate::Module,
+ ) -> Option<crate::BuiltIn> {
+ let ep_function = match self.ty {
+ FunctionType::Function(_) => return None,
+ FunctionType::EntryPoint(ep_index) => &module.entry_points[ep_index as usize].function,
+ };
+ let mut built_in = None;
+ loop {
+ match self.expressions[expression] {
+ crate::Expression::FunctionArgument(arg_index) => {
+ return match ep_function.arguments[arg_index as usize].binding {
+ Some(crate::Binding::BuiltIn(bi)) => Some(bi),
+ _ => built_in,
+ };
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ match *self.info[base].ty.inner_with(&module.types) {
+ crate::TypeInner::Struct { ref members, .. } => {
+ if let Some(crate::Binding::BuiltIn(bi)) =
+ members[index as usize].binding
+ {
+ built_in = Some(bi);
+ }
+ }
+ _ => return None,
+ }
+ expression = base;
+ }
+ _ => return None,
+ }
+ }
+ }
+}
+
+impl crate::Expression {
+ /// Returns the ref count, upon reaching which this expression
+ /// should be considered for baking.
+ ///
+ /// Note: we have to cache any expressions that depend on the control flow,
+ /// or otherwise they may be moved into a non-uniform control flow, accidentally.
+ /// See the [module-level documentation][emit] for details.
+ ///
+ /// [emit]: index.html#expression-evaluation-time
+ const fn bake_ref_count(&self) -> usize {
+ match *self {
+ // accesses are never cached, only loads are
+ crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => usize::MAX,
+ // sampling may use the control flow, and image ops look better by themselves
+ crate::Expression::ImageSample { .. } | crate::Expression::ImageLoad { .. } => 1,
+ // derivatives use the control flow
+ crate::Expression::Derivative { .. } => 1,
+ // TODO: We need a better fix for named `Load` expressions
+ // More info - https://github.com/gfx-rs/naga/pull/914
+ // And https://github.com/gfx-rs/naga/issues/910
+ crate::Expression::Load { .. } => 1,
+ // cache expressions that are referenced multiple times
+ _ => 2,
+ }
+ }
+}
+
+/// Helper function that returns the string corresponding to the [`BinaryOperator`](crate::BinaryOperator)
+/// # Notes
+/// Used by `glsl-out`, `msl-out`, `wgsl-out`, `hlsl-out`.
+const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
+ use crate::BinaryOperator as Bo;
+ match op {
+ Bo::Add => "+",
+ Bo::Subtract => "-",
+ Bo::Multiply => "*",
+ Bo::Divide => "/",
+ Bo::Modulo => "%",
+ Bo::Equal => "==",
+ Bo::NotEqual => "!=",
+ Bo::Less => "<",
+ Bo::LessEqual => "<=",
+ Bo::Greater => ">",
+ Bo::GreaterEqual => ">=",
+ Bo::And => "&",
+ Bo::ExclusiveOr => "^",
+ Bo::InclusiveOr => "|",
+ Bo::LogicalAnd => "&&",
+ Bo::LogicalOr => "||",
+ Bo::ShiftLeft => "<<",
+ Bo::ShiftRight => ">>",
+ }
+}
+
+/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize)
+/// # Notes
+/// Used by `msl-out`, `wgsl-out`, `hlsl-out`.
+const fn vector_size_str(size: crate::VectorSize) -> &'static str {
+ match size {
+ crate::VectorSize::Bi => "2",
+ crate::VectorSize::Tri => "3",
+ crate::VectorSize::Quad => "4",
+ }
+}
+
+impl crate::TypeInner {
+ const fn is_handle(&self) -> bool {
+ match *self {
+ crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => true,
+ _ => false,
+ }
+ }
+}
+
+impl crate::Statement {
+ /// Returns true if the statement directly terminates the current block.
+ ///
+ /// Used to decide whether case blocks require a explicit `break`.
+ pub const fn is_terminator(&self) -> bool {
+ match *self {
+ crate::Statement::Break
+ | crate::Statement::Continue
+ | crate::Statement::Return { .. }
+ | crate::Statement::Kill => true,
+ _ => false,
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs
new file mode 100644
index 0000000000..a3a9c52dcc
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/keywords.rs
@@ -0,0 +1,217 @@
+//TODO: find a complete list
+pub const RESERVED: &[&str] = &[
+ // control flow
+ "break",
+ "if",
+ "else",
+ "continue",
+ "goto",
+ "do",
+ "while",
+ "for",
+ "switch",
+ "case",
+ // types and values
+ "void",
+ "unsigned",
+ "signed",
+ "bool",
+ "char",
+ "int",
+ "uint",
+ "long",
+ "float",
+ "double",
+ "char8_t",
+ "wchar_t",
+ "true",
+ "false",
+ "nullptr",
+ "union",
+ "class",
+ "struct",
+ "enum",
+ // other
+ "main",
+ "using",
+ "decltype",
+ "sizeof",
+ "typeof",
+ "typedef",
+ "explicit",
+ "export",
+ "friend",
+ "namespace",
+ "operator",
+ "public",
+ "template",
+ "typename",
+ "typeid",
+ "co_await",
+ "co_return",
+ "co_yield",
+ "module",
+ "import",
+ "ray_data",
+ "vec_step",
+ "visible",
+ "as_type",
+ "this",
+ // qualifiers
+ "mutable",
+ "static",
+ "volatile",
+ "restrict",
+ "const",
+ "non-temporal",
+ "dereferenceable",
+ "invariant",
+ // exceptions
+ "throw",
+ "try",
+ "catch",
+ // operators
+ "const_cast",
+ "dynamic_cast",
+ "reinterpret_cast",
+ "static_cast",
+ "new",
+ "delete",
+ "and",
+ "and_eq",
+ "bitand",
+ "bitor",
+ "compl",
+ "not",
+ "not_eq",
+ "or",
+ "or_eq",
+ "xor",
+ "xor_eq",
+ "compl",
+ // Metal-specific
+ "constant",
+ "device",
+ "threadgroup",
+ "threadgroup_imageblock",
+ "kernel",
+ "compute",
+ "vertex",
+ "fragment",
+ "read_only",
+ "write_only",
+ "read_write",
+ "auto",
+ // Metal reserved types
+ "llong",
+ "ullong",
+ "quad",
+ "complex",
+ "imaginary",
+ // Metal constants
+ "CHAR_BIT",
+ "SCHAR_MAX",
+ "SCHAR_MIN",
+ "UCHAR_MAX",
+ "CHAR_MAX",
+ "CHAR_MIN",
+ "USHRT_MAX",
+ "SHRT_MAX",
+ "SHRT_MIN",
+ "UINT_MAX",
+ "INT_MAX",
+ "INT_MIN",
+ "ULONG_MAX",
+ "LONG_MAX",
+ "LONG_MIN",
+ "ULLONG_MAX",
+ "LLONG_MAX",
+ "LLONG_MIN",
+ "FLT_DIG",
+ "FLT_MANT_DIG",
+ "FLT_MAX_10_EXP",
+ "FLT_MAX_EXP",
+ "FLT_MIN_10_EXP",
+ "FLT_MIN_EXP",
+ "FLT_RADIX",
+ "FLT_MAX",
+ "FLT_MIN",
+ "FLT_EPSILON",
+ "FLT_DECIMAL_DIG",
+ "FP_ILOGB0",
+ "FP_ILOGB0",
+ "FP_ILOGBNAN",
+ "FP_ILOGBNAN",
+ "MAXFLOAT",
+ "HUGE_VALF",
+ "INFINITY",
+ "NAN",
+ "M_E_F",
+ "M_LOG2E_F",
+ "M_LOG10E_F",
+ "M_LN2_F",
+ "M_LN10_F",
+ "M_PI_F",
+ "M_PI_2_F",
+ "M_PI_4_F",
+ "M_1_PI_F",
+ "M_2_PI_F",
+ "M_2_SQRTPI_F",
+ "M_SQRT2_F",
+ "M_SQRT1_2_F",
+ "HALF_DIG",
+ "HALF_MANT_DIG",
+ "HALF_MAX_10_EXP",
+ "HALF_MAX_EXP",
+ "HALF_MIN_10_EXP",
+ "HALF_MIN_EXP",
+ "HALF_RADIX",
+ "HALF_MAX",
+ "HALF_MIN",
+ "HALF_EPSILON",
+ "HALF_DECIMAL_DIG",
+ "MAXHALF",
+ "HUGE_VALH",
+ "M_E_H",
+ "M_LOG2E_H",
+ "M_LOG10E_H",
+ "M_LN2_H",
+ "M_LN10_H",
+ "M_PI_H",
+ "M_PI_2_H",
+ "M_PI_4_H",
+ "M_1_PI_H",
+ "M_2_PI_H",
+ "M_2_SQRTPI_H",
+ "M_SQRT2_H",
+ "M_SQRT1_2_H",
+ "DBL_DIG",
+ "DBL_MANT_DIG",
+ "DBL_MAX_10_EXP",
+ "DBL_MAX_EXP",
+ "DBL_MIN_10_EXP",
+ "DBL_MIN_EXP",
+ "DBL_RADIX",
+ "DBL_MAX",
+ "DBL_MIN",
+ "DBL_EPSILON",
+ "DBL_DECIMAL_DIG",
+ "MAXDOUBLE",
+ "HUGE_VAL",
+ "M_E",
+ "M_LOG2E",
+ "M_LOG10E",
+ "M_LN2",
+ "M_LN10",
+ "M_PI",
+ "M_PI_2",
+ "M_PI_4",
+ "M_1_PI",
+ "M_2_PI",
+ "M_2_SQRTPI",
+ "M_SQRT2",
+ "M_SQRT1_2",
+ // Naga utilities
+ "DefaultConstructible",
+ "clamped_lod_e",
+];
diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs
new file mode 100644
index 0000000000..a8ed2dd0d5
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/mod.rs
@@ -0,0 +1,497 @@
+/*!
+Backend for [MSL][msl] (Metal Shading Language).
+
+## Binding model
+
+Metal's bindings are flat per resource. Since there isn't an obvious mapping
+from SPIR-V's descriptor sets, we require a separate mapping provided in the options.
+This mapping may have one or more resource end points for each descriptor set + index
+pair.
+
+## Entry points
+
+Even though MSL and our IR appear to be similar in that the entry points in both can
+accept arguments and return values, the restrictions are different.
+MSL allows the varyings to be either in separate arguments, or inside a single
+`[[stage_in]]` struct. We gather input varyings and form this artificial structure.
+We also add all the (non-Private) globals into the arguments.
+
+At the beginning of the entry point, we assign the local constants and re-compose
+the arguments as they are declared on IR side, so that the rest of the logic can
+pretend that MSL doesn't have all the restrictions it has.
+
+For the result type, if it's a structure, we re-compose it with a temporary value
+holding the result.
+
+[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+*/
+
+use crate::{arena::Handle, proc::index, valid::ModuleInfo};
+use std::{
+ fmt::{Error as FmtError, Write},
+ ops,
+};
+
+mod keywords;
+pub mod sampler;
+mod writer;
+
+pub use writer::Writer;
+
+pub type Slot = u8;
+pub type InlineSamplerIndex = u8;
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum BindSamplerTarget {
+ Resource(Slot),
+ Inline(InlineSamplerIndex),
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct BindTarget {
+ pub buffer: Option<Slot>,
+ pub texture: Option<Slot>,
+ pub sampler: Option<BindSamplerTarget>,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+ pub mutable: bool,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
+
+#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct PerStageResources {
+ pub resources: BindingMap,
+
+ pub push_constant_buffer: Option<Slot>,
+
+ /// The slot of a buffer that contains an array of `u32`,
+ /// one for the size of each bound buffer that contains a runtime array,
+ /// in order of [`crate::GlobalVariable`] declarations.
+ pub sizes_buffer: Option<Slot>,
+}
+
+#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct PerStageMap {
+ pub vs: PerStageResources,
+ pub fs: PerStageResources,
+ pub cs: PerStageResources,
+}
+
+impl ops::Index<crate::ShaderStage> for PerStageMap {
+ type Output = PerStageResources;
+ fn index(&self, stage: crate::ShaderStage) -> &PerStageResources {
+ match stage {
+ crate::ShaderStage::Vertex => &self.vs,
+ crate::ShaderStage::Fragment => &self.fs,
+ crate::ShaderStage::Compute => &self.cs,
+ }
+ }
+}
+
+enum ResolvedBinding {
+ BuiltIn(crate::BuiltIn),
+ Attribute(u32),
+ Color(u32),
+ User {
+ prefix: &'static str,
+ index: u32,
+ interpolation: Option<ResolvedInterpolation>,
+ },
+ Resource(BindTarget),
+}
+
+#[derive(Copy, Clone)]
+enum ResolvedInterpolation {
+ CenterPerspective,
+ CenterNoPerspective,
+ CentroidPerspective,
+ CentroidNoPerspective,
+ SamplePerspective,
+ SampleNoPerspective,
+ Flat,
+}
+
+// Note: some of these should be removed in favor of proper IR validation.
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error(transparent)]
+ Format(#[from] FmtError),
+ #[error("bind target {0:?} is empty")]
+ UnimplementedBindTarget(BindTarget),
+ #[error("composing of {0:?} is not implemented yet")]
+ UnsupportedCompose(Handle<crate::Type>),
+ #[error("operation {0:?} is not implemented yet")]
+ UnsupportedBinaryOp(crate::BinaryOperator),
+ #[error("standard function '{0}' is not implemented yet")]
+ UnsupportedCall(String),
+ #[error("feature '{0}' is not implemented yet")]
+ FeatureNotImplemented(String),
+ #[error("module is not valid")]
+ Validation,
+ #[error("BuiltIn {0:?} is not supported")]
+ UnsupportedBuiltIn(crate::BuiltIn),
+ #[error("capability {0:?} is not supported")]
+ CapabilityNotSupported(crate::valid::Capabilities),
+ #[error("address space {0:?} is not supported for target MSL version")]
+ UnsupportedAddressSpace(crate::AddressSpace),
+ #[error("attribute '{0}' is not supported for target MSL version")]
+ UnsupportedAttribute(String),
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("mapping of {0:?} is missing")]
+ MissingBinding(crate::ResourceBinding),
+ #[error("mapping for push constants is missing")]
+ MissingPushConstants,
+ #[error("mapping for sizes buffer is missing")]
+ MissingSizesBuffer,
+}
+
+/// Points in the MSL code where we might emit a pipeline input or output.
+///
+/// Note that, even though vertex shaders' outputs are always fragment
+/// shaders' inputs, we still need to distinguish `VertexOutput` and
+/// `FragmentInput`, since there are certain differences in the way
+/// [`ResolvedBinding`s] are represented on either side.
+///
+/// [`ResolvedBinding`s]: ResolvedBinding
+#[derive(Clone, Copy, Debug)]
+enum LocationMode {
+ /// Input to the vertex shader.
+ VertexInput,
+
+ /// Output from the vertex shader.
+ VertexOutput,
+
+ /// Input to the fragment shader.
+ FragmentInput,
+
+ /// Output from the fragment shader.
+ FragmentOutput,
+
+ /// Compute shader input or output.
+ Uniform,
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// (Major, Minor) target version of the Metal Shading Language.
+ pub lang_version: (u8, u8),
+ /// Map of per-stage resources to slots.
+ pub per_stage_map: PerStageMap,
+ /// Samplers to be inlined into the code.
+ pub inline_samplers: Vec<sampler::InlineSampler>,
+ /// Make it possible to link different stages via SPIRV-Cross.
+ pub spirv_cross_compatibility: bool,
+ /// Don't panic on missing bindings, instead generate invalid MSL.
+ pub fake_missing_bindings: bool,
+ /// Bounds checking policies.
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub bounds_check_policies: index::BoundsCheckPolicies,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ lang_version: (2, 0),
+ per_stage_map: PerStageMap::default(),
+ inline_samplers: Vec::new(),
+ spirv_cross_compatibility: false,
+ fake_missing_bindings: true,
+ bounds_check_policies: index::BoundsCheckPolicies::default(),
+ }
+ }
+}
+
+/// A subset of options that are meant to be changed per pipeline.
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// Allow `BuiltIn::PointSize` in the vertex shader.
+ ///
+ /// Metal doesn't like this for non-point primitive topologies.
+ pub allow_point_size: bool,
+}
+
+impl Options {
+ fn resolve_local_binding(
+ &self,
+ binding: &crate::Binding,
+ mode: LocationMode,
+ ) -> Result<ResolvedBinding, Error> {
+ match *binding {
+ crate::Binding::BuiltIn(mut built_in) => {
+ if let crate::BuiltIn::Position { ref mut invariant } = built_in {
+ if *invariant && self.lang_version < (2, 1) {
+ return Err(Error::UnsupportedAttribute("invariant".to_string()));
+ }
+
+ // The 'invariant' attribute may only appear on vertex
+ // shader outputs, not fragment shader inputs.
+ if !matches!(mode, LocationMode::VertexOutput) {
+ *invariant = false;
+ }
+ }
+
+ Ok(ResolvedBinding::BuiltIn(built_in))
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => match mode {
+ LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)),
+ LocationMode::FragmentOutput => Ok(ResolvedBinding::Color(location)),
+ LocationMode::VertexOutput | LocationMode::FragmentInput => {
+ Ok(ResolvedBinding::User {
+ prefix: if self.spirv_cross_compatibility {
+ "locn"
+ } else {
+ "loc"
+ },
+ index: location,
+ interpolation: {
+ // unwrap: The verifier ensures that vertex shader outputs and fragment
+ // shader inputs always have fully specified interpolation, and that
+ // sampling is `None` only for Flat interpolation.
+ let interpolation = interpolation.unwrap();
+ let sampling = sampling.unwrap_or(crate::Sampling::Center);
+ Some(ResolvedInterpolation::from_binding(interpolation, sampling))
+ },
+ })
+ }
+ LocationMode::Uniform => {
+ log::error!(
+ "Unexpected Binding::Location({}) for the Uniform mode",
+ location
+ );
+ Err(Error::Validation)
+ }
+ },
+ }
+ }
+
+ fn resolve_resource_binding(
+ &self,
+ stage: crate::ShaderStage,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ match self.per_stage_map[stage].resources.get(res_binding) {
+ Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingBinding(res_binding.clone())),
+ }
+ }
+
+ const fn resolve_push_constants(
+ &self,
+ stage: crate::ShaderStage,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = match stage {
+ crate::ShaderStage::Vertex => self.per_stage_map.vs.push_constant_buffer,
+ crate::ShaderStage::Fragment => self.per_stage_map.fs.push_constant_buffer,
+ crate::ShaderStage::Compute => self.per_stage_map.cs.push_constant_buffer,
+ };
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ texture: None,
+ sampler: None,
+ binding_array_size: None,
+ mutable: false,
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingPushConstants),
+ }
+ }
+
+ fn resolve_sizes_buffer(
+ &self,
+ stage: crate::ShaderStage,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = self.per_stage_map[stage].sizes_buffer;
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ texture: None,
+ sampler: None,
+ binding_array_size: None,
+ mutable: false,
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingSizesBuffer),
+ }
+ }
+}
+
+impl ResolvedBinding {
+ fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> {
+ match *self {
+ Self::Resource(BindTarget {
+ sampler: Some(BindSamplerTarget::Inline(index)),
+ ..
+ }) => Some(&options.inline_samplers[index as usize]),
+ _ => None,
+ }
+ }
+
+ const fn as_bind_target(&self) -> Option<&BindTarget> {
+ match *self {
+ Self::Resource(ref target) => Some(target),
+ _ => None,
+ }
+ }
+
+ fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
+ write!(out, " [[")?;
+ match *self {
+ Self::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let name = match built_in {
+ Bi::Position { invariant: false } => "position",
+ Bi::Position { invariant: true } => "position, invariant",
+ // vertex
+ Bi::BaseInstance => "base_instance",
+ Bi::BaseVertex => "base_vertex",
+ Bi::ClipDistance => "clip_distance",
+ Bi::InstanceIndex => "instance_id",
+ Bi::PointSize => "point_size",
+ Bi::VertexIndex => "vertex_id",
+ // fragment
+ Bi::FragDepth => "depth(any)",
+ Bi::FrontFacing => "front_facing",
+ Bi::PrimitiveIndex => "primitive_id",
+ Bi::SampleIndex => "sample_id",
+ Bi::SampleMask => "sample_mask",
+ // compute
+ Bi::GlobalInvocationId => "thread_position_in_grid",
+ Bi::LocalInvocationId => "thread_position_in_threadgroup",
+ Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
+ Bi::WorkGroupId => "threadgroup_position_in_grid",
+ Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
+ Bi::NumWorkGroups => "threadgroups_per_grid",
+ Bi::CullDistance | Bi::ViewIndex => {
+ return Err(Error::UnsupportedBuiltIn(built_in))
+ }
+ };
+ write!(out, "{}", name)?;
+ }
+ Self::Attribute(index) => write!(out, "attribute({})", index)?,
+ Self::Color(index) => write!(out, "color({})", index)?,
+ Self::User {
+ prefix,
+ index,
+ interpolation,
+ } => {
+ write!(out, "user({}{})", prefix, index)?;
+ if let Some(interpolation) = interpolation {
+ write!(out, ", ")?;
+ interpolation.try_fmt(out)?;
+ }
+ }
+ Self::Resource(ref target) => {
+ if let Some(id) = target.buffer {
+ write!(out, "buffer({})", id)?;
+ } else if let Some(id) = target.texture {
+ write!(out, "texture({})", id)?;
+ } else if let Some(BindSamplerTarget::Resource(id)) = target.sampler {
+ write!(out, "sampler({})", id)?;
+ } else {
+ return Err(Error::UnimplementedBindTarget(target.clone()));
+ }
+ }
+ }
+ write!(out, "]]")?;
+ Ok(())
+ }
+}
+
+impl ResolvedInterpolation {
+ const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self {
+ use crate::Interpolation as I;
+ use crate::Sampling as S;
+
+ match (interpolation, sampling) {
+ (I::Perspective, S::Center) => Self::CenterPerspective,
+ (I::Perspective, S::Centroid) => Self::CentroidPerspective,
+ (I::Perspective, S::Sample) => Self::SamplePerspective,
+ (I::Linear, S::Center) => Self::CenterNoPerspective,
+ (I::Linear, S::Centroid) => Self::CentroidNoPerspective,
+ (I::Linear, S::Sample) => Self::SampleNoPerspective,
+ (I::Flat, _) => Self::Flat,
+ }
+ }
+
+ fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> {
+ let identifier = match self {
+ Self::CenterPerspective => "center_perspective",
+ Self::CenterNoPerspective => "center_no_perspective",
+ Self::CentroidPerspective => "centroid_perspective",
+ Self::CentroidNoPerspective => "centroid_no_perspective",
+ Self::SamplePerspective => "sample_perspective",
+ Self::SampleNoPerspective => "sample_no_perspective",
+ Self::Flat => "flat",
+ };
+ out.write_str(identifier)?;
+ Ok(())
+ }
+}
+
+/// Information about a translated module that is required
+/// for the use of the result.
+pub struct TranslationInfo {
+ /// Mapping of the entry point names. Each item in the array
+ /// corresponds to an entry point index.
+ ///
+ ///Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+pub fn write_string(
+ module: &crate::Module,
+ info: &ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+) -> Result<(String, TranslationInfo), Error> {
+ let mut w = writer::Writer::new(String::new());
+ let info = w.write(module, info, options, pipeline_options)?;
+ Ok((w.finish(), info))
+}
+
+#[test]
+fn test_error_size() {
+ use std::mem::size_of;
+ assert_eq!(size_of::<Error>(), 32);
+}
diff --git a/third_party/rust/naga/src/back/msl/sampler.rs b/third_party/rust/naga/src/back/msl/sampler.rs
new file mode 100644
index 0000000000..3b95fa3781
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/sampler.rs
@@ -0,0 +1,175 @@
+#[cfg(feature = "deserialize")]
+use serde::Deserialize;
+#[cfg(feature = "serialize")]
+use serde::Serialize;
+use std::{num::NonZeroU32, ops::Range};
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Coord {
+ Normalized,
+ Pixel,
+}
+
+impl Default for Coord {
+ fn default() -> Self {
+ Self::Normalized
+ }
+}
+
+impl Coord {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Normalized => "normalized",
+ Self::Pixel => "pixel",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Address {
+ Repeat,
+ MirroredRepeat,
+ ClampToEdge,
+ ClampToZero,
+ ClampToBorder,
+}
+
+impl Default for Address {
+ fn default() -> Self {
+ Self::ClampToEdge
+ }
+}
+
+impl Address {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Repeat => "repeat",
+ Self::MirroredRepeat => "mirrored_repeat",
+ Self::ClampToEdge => "clamp_to_edge",
+ Self::ClampToZero => "clamp_to_zero",
+ Self::ClampToBorder => "clamp_to_border",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum BorderColor {
+ TransparentBlack,
+ OpaqueBlack,
+ OpaqueWhite,
+}
+
+impl Default for BorderColor {
+ fn default() -> Self {
+ Self::TransparentBlack
+ }
+}
+
+impl BorderColor {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::TransparentBlack => "transparent_black",
+ Self::OpaqueBlack => "opaque_black",
+ Self::OpaqueWhite => "opaque_white",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Filter {
+ Nearest,
+ Linear,
+}
+
+impl Filter {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Nearest => "nearest",
+ Self::Linear => "linear",
+ }
+ }
+}
+
+impl Default for Filter {
+ fn default() -> Self {
+ Self::Nearest
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum CompareFunc {
+ Never,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+ Equal,
+ NotEqual,
+ Always,
+}
+
+impl Default for CompareFunc {
+ fn default() -> Self {
+ Self::Never
+ }
+}
+
+impl CompareFunc {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Never => "never",
+ Self::Less => "less",
+ Self::LessEqual => "less_equal",
+ Self::Greater => "greater",
+ Self::GreaterEqual => "greater_equal",
+ Self::Equal => "equal",
+ Self::NotEqual => "not_equal",
+ Self::Always => "always",
+ }
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct InlineSampler {
+ pub coord: Coord,
+ pub address: [Address; 3],
+ pub border_color: BorderColor,
+ pub mag_filter: Filter,
+ pub min_filter: Filter,
+ pub mip_filter: Option<Filter>,
+ pub lod_clamp: Option<Range<f32>>,
+ pub max_anisotropy: Option<NonZeroU32>,
+ pub compare_func: CompareFunc,
+}
+
+impl Eq for InlineSampler {}
+
+#[allow(clippy::derive_hash_xor_eq)]
+impl std::hash::Hash for InlineSampler {
+ fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
+ self.coord.hash(hasher);
+ self.address.hash(hasher);
+ self.border_color.hash(hasher);
+ self.mag_filter.hash(hasher);
+ self.min_filter.hash(hasher);
+ self.mip_filter.hash(hasher);
+ self.lod_clamp
+ .as_ref()
+ .map(|range| (range.start.to_bits(), range.end.to_bits()))
+ .hash(hasher);
+ self.max_anisotropy.hash(hasher);
+ self.compare_func.hash(hasher);
+ }
+}
diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs
new file mode 100644
index 0000000000..9147fbe398
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/writer.rs
@@ -0,0 +1,3985 @@
+use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
+use crate::{
+ arena::Handle,
+ back,
+ proc::index,
+ proc::{self, NameKey, TypeResolution},
+ valid, FastHashMap, FastHashSet,
+};
+use bit_set::BitSet;
+use std::{
+ fmt::{Display, Error as FmtError, Formatter, Write},
+ iter,
+};
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+const NAMESPACE: &str = "metal";
+// The name of the array member of the Metal struct types we generate to
+// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
+// for details.
+const WRAPPED_ARRAY_FIELD: &str = "inner";
+// This is a hack: we need to pass a pointer to an atomic,
+// but generally the backend isn't putting "&" in front of every pointer.
+// Some more general handling of pointers is needed to be implemented here.
+const ATOMIC_REFERENCE: &str = "&";
+
+/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
+///
+/// The `sizes` slice determines whether this function writes a
+/// scalar, vector, or matrix type:
+///
+/// - An empty slice produces a scalar type.
+/// - A one-element slice produces a vector type.
+/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
+fn put_numeric_type(
+ out: &mut impl Write,
+ kind: crate::ScalarKind,
+ sizes: &[crate::VectorSize],
+) -> Result<(), FmtError> {
+ match (kind, sizes) {
+ (kind, &[]) => {
+ write!(out, "{}", kind.to_msl_name())
+ }
+ (kind, &[rows]) => {
+ write!(
+ out,
+ "{}::{}{}",
+ NAMESPACE,
+ kind.to_msl_name(),
+ back::vector_size_str(rows)
+ )
+ }
+ (kind, &[rows, columns]) => {
+ write!(
+ out,
+ "{}::{}{}x{}",
+ NAMESPACE,
+ kind.to_msl_name(),
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )
+ }
+ (_, _) => Ok(()), // not meaningful
+ }
+}
+
+/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
+const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
+
+struct TypeContext<'a> {
+ handle: Handle<crate::Type>,
+ module: &'a crate::Module,
+ names: &'a FastHashMap<NameKey, String>,
+ access: crate::StorageAccess,
+ binding: Option<&'a super::ResolvedBinding>,
+ first_time: bool,
+}
+
+impl<'a> Display for TypeContext<'a> {
+ fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
+ let ty = &self.module.types[self.handle];
+ if ty.needs_alias() && !self.first_time {
+ let name = &self.names[&NameKey::Type(self.handle)];
+ return write!(out, "{}", name);
+ }
+
+ match ty.inner {
+ crate::TypeInner::Scalar { kind, .. } => put_numeric_type(out, kind, &[]),
+ crate::TypeInner::Atomic { kind, .. } => {
+ write!(out, "{}::atomic_{}", NAMESPACE, kind.to_msl_name())
+ }
+ crate::TypeInner::Vector { size, kind, .. } => put_numeric_type(out, kind, &[size]),
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(out, crate::ScalarKind::Float, &[rows, columns])
+ }
+ crate::TypeInner::Pointer { base, space } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ let space_name = match space.to_msl_name() {
+ Some(name) => name,
+ None => return Ok(()),
+ };
+ write!(out, "{} {}&", space_name, sub)
+ }
+ crate::TypeInner::ValuePointer {
+ size,
+ kind,
+ width: _,
+ space,
+ } => {
+ match space.to_msl_name() {
+ Some(name) => write!(out, "{} ", name)?,
+ None => return Ok(()),
+ };
+ match size {
+ Some(rows) => put_numeric_type(out, kind, &[rows])?,
+ None => put_numeric_type(out, kind, &[])?,
+ };
+
+ write!(out, "&")
+ }
+ crate::TypeInner::Array { base, .. } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ // Array lengths go at the end of the type definition,
+ // so just print the element type here.
+ write!(out, "{}", sub)
+ }
+ crate::TypeInner::Struct { .. } => unreachable!(),
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_str = match dim {
+ crate::ImageDimension::D1 => "1d",
+ crate::ImageDimension::D2 => "2d",
+ crate::ImageDimension::D3 => "3d",
+ crate::ImageDimension::Cube => "cube",
+ };
+ let (texture_str, msaa_str, kind, access) = match class {
+ crate::ImageClass::Sampled { kind, multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("texture", msaa_str, kind, access)
+ }
+ crate::ImageClass::Depth { multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("depth", msaa_str, crate::ScalarKind::Float, access)
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let access = if self
+ .access
+ .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ "read_write"
+ } else if self.access.contains(crate::StorageAccess::STORE) {
+ "write"
+ } else if self.access.contains(crate::StorageAccess::LOAD) {
+ "read"
+ } else {
+ log::warn!(
+ "Storage access for {:?} (name '{}'): {:?}",
+ self.handle,
+ ty.name.as_deref().unwrap_or_default(),
+ self.access
+ );
+ unreachable!("module is not valid");
+ };
+ ("texture", "", format.into(), access)
+ }
+ };
+ let base_name = kind.to_msl_name();
+ let array_str = if arrayed { "_array" } else { "" };
+ write!(
+ out,
+ "{}::{}{}{}{}<{}, {}::access::{}>",
+ NAMESPACE,
+ texture_str,
+ dim_str,
+ msaa_str,
+ array_str,
+ base_name,
+ NAMESPACE,
+ access,
+ )
+ }
+ crate::TypeInner::Sampler { comparison: _ } => {
+ write!(out, "{}::sampler", NAMESPACE)
+ }
+ crate::TypeInner::BindingArray { base, size } => {
+ let base_tyname = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+
+ if let Some(&super::ResolvedBinding::Resource(super::BindTarget {
+ binding_array_size: Some(override_size),
+ ..
+ })) = self.binding
+ {
+ write!(
+ out,
+ "{}::array<{}, {}>",
+ NAMESPACE, base_tyname, override_size
+ )
+ } else if let crate::ArraySize::Constant(size) = size {
+ let constant_ctx = ConstantContext {
+ handle: size,
+ arena: &self.module.constants,
+ names: self.names,
+ first_time: false,
+ };
+ write!(
+ out,
+ "{}::array<{}, {}>",
+ NAMESPACE, base_tyname, constant_ctx
+ )
+ } else {
+ unreachable!("metal requires all arrays be constant sized");
+ }
+ }
+ }
+ }
+}
+
+struct TypedGlobalVariable<'a> {
+ module: &'a crate::Module,
+ names: &'a FastHashMap<NameKey, String>,
+ handle: Handle<crate::GlobalVariable>,
+ usage: valid::GlobalUse,
+ binding: Option<&'a super::ResolvedBinding>,
+ reference: bool,
+}
+
+impl<'a> TypedGlobalVariable<'a> {
+ fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
+ let var = &self.module.global_variables[self.handle];
+ let name = &self.names[&NameKey::GlobalVariable(self.handle)];
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => match self.module.types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ crate::TypeInner::BindingArray { base, .. } => {
+ match self.module.types[base].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ _ => crate::StorageAccess::default(),
+ }
+ }
+ _ => crate::StorageAccess::default(),
+ },
+ };
+ let ty_name = TypeContext {
+ handle: var.ty,
+ module: self.module,
+ names: self.names,
+ access: storage_access,
+ binding: self.binding,
+ first_time: false,
+ };
+
+ let (space, access, reference) = match var.space.to_msl_name() {
+ Some(space) if self.reference => {
+ let access = if var.space.needs_access_qualifier()
+ && !self.usage.contains(valid::GlobalUse::WRITE)
+ {
+ "const"
+ } else {
+ ""
+ };
+ (space, access, "&")
+ }
+ _ => ("", "", ""),
+ };
+
+ Ok(write!(
+ out,
+ "{}{}{}{}{}{} {}",
+ space,
+ if space.is_empty() { "" } else { " " },
+ ty_name,
+ if access.is_empty() { "" } else { " " },
+ access,
+ reference,
+ name,
+ )?)
+ }
+}
+
+struct ConstantContext<'a> {
+ handle: Handle<crate::Constant>,
+ arena: &'a crate::Arena<crate::Constant>,
+ names: &'a FastHashMap<NameKey, String>,
+ first_time: bool,
+}
+
+impl<'a> Display for ConstantContext<'a> {
+ fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
+ let con = &self.arena[self.handle];
+ if con.needs_alias() && !self.first_time {
+ let name = &self.names[&NameKey::Constant(self.handle)];
+ return write!(out, "{}", name);
+ }
+
+ match con.inner {
+ crate::ConstantInner::Scalar { value, width: _ } => match value {
+ crate::ScalarValue::Sint(value) => {
+ write!(out, "{}", value)
+ }
+ crate::ScalarValue::Uint(value) => {
+ write!(out, "{}u", value)
+ }
+ crate::ScalarValue::Float(value) => {
+ if value.is_infinite() {
+ let sign = if value.is_sign_negative() { "-" } else { "" };
+ write!(out, "{}INFINITY", sign)
+ } else if value.is_nan() {
+ write!(out, "NAN")
+ } else {
+ let suffix = if value.fract() == 0.0 { ".0" } else { "" };
+
+ write!(out, "{}{}", value, suffix)
+ }
+ }
+ crate::ScalarValue::Bool(value) => {
+ write!(out, "{}", value)
+ }
+ },
+ crate::ConstantInner::Composite { .. } => unreachable!("should be aliased"),
+ }
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ names: FastHashMap<NameKey, String>,
+ named_expressions: crate::NamedExpressions,
+ /// Set of expressions that need to be baked to avoid unnecessary repetition in output
+ need_bake_expressions: back::NeedBakeExpressions,
+ namer: proc::Namer,
+ #[cfg(test)]
+ put_expression_stack_pointers: FastHashSet<*const ()>,
+ #[cfg(test)]
+ put_block_stack_pointers: FastHashSet<*const ()>,
+ /// Set of (struct type, struct field index) denoting which fields require
+ /// padding inserted **before** them (i.e. between fields at index - 1 and index)
+ struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
+}
+
+impl crate::ScalarKind {
+ const fn to_msl_name(self) -> &'static str {
+ match self {
+ Self::Float => "float",
+ Self::Sint => "int",
+ Self::Uint => "uint",
+ Self::Bool => "bool",
+ }
+ }
+}
+
+const fn separate(need_separator: bool) -> &'static str {
+ if need_separator {
+ ","
+ } else {
+ ""
+ }
+}
+
+fn should_pack_struct_member(
+ members: &[crate::StructMember],
+ span: u32,
+ index: usize,
+ module: &crate::Module,
+) -> Option<crate::ScalarKind> {
+ let member = &members[index];
+ //Note: this is imperfect - the same structure can be used for host-shared
+ // things, where packed float would matter.
+ if member.binding.is_some() {
+ return None;
+ }
+
+ let ty_inner = &module.types[member.ty].inner;
+ let last_offset = member.offset + ty_inner.size(&module.constants);
+ let next_offset = match members.get(index + 1) {
+ Some(next) => next.offset,
+ None => span,
+ };
+ let is_tight = next_offset == last_offset;
+
+ match *ty_inner {
+ crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ width: 4,
+ kind,
+ } if member.offset & 0xF != 0 || is_tight => Some(kind),
+ _ => None,
+ }
+}
+
+fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
+ match arena[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ if let Some(member) = members.last() {
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = arena[member.ty].inner
+ {
+ return true;
+ }
+ }
+ false
+ }
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => true,
+ _ => false,
+ }
+}
+
+impl crate::AddressSpace {
+ /// Returns true if global variables in this address space are
+ /// passed in function arguments. These arguments need to be
+ /// passed through any functions called from the entry point.
+ const fn needs_pass_through(&self) -> bool {
+ match *self {
+ Self::Uniform
+ | Self::Storage { .. }
+ | Self::Private
+ | Self::WorkGroup
+ | Self::PushConstant
+ | Self::Handle => true,
+ Self::Function => false,
+ }
+ }
+
+ /// Returns true if the address space may need a "const" qualifier.
+ const fn needs_access_qualifier(&self) -> bool {
+ match *self {
+ //Note: we are ignoring the storage access here, and instead
+ // rely on the actual use of a global by functions. This means we
+ // may end up with "const" even if the binding is read-write,
+ // and that should be OK.
+ Self::Storage { .. } => true,
+ // These should always be read-write.
+ Self::Private | Self::WorkGroup => false,
+ // These translate to `constant` address space, no need for qualifiers.
+ Self::Uniform | Self::PushConstant => false,
+ // Not applicable.
+ Self::Handle | Self::Function => false,
+ }
+ }
+
+ const fn to_msl_name(self) -> Option<&'static str> {
+ match self {
+ Self::Handle => None,
+ Self::Uniform | Self::PushConstant => Some("constant"),
+ Self::Storage { .. } => Some("device"),
+ Self::Private | Self::Function => Some("thread"),
+ Self::WorkGroup => Some("threadgroup"),
+ }
+ }
+}
+
+impl crate::Type {
+ // Returns `true` if we need to emit an alias for this type.
+ const fn needs_alias(&self) -> bool {
+ use crate::TypeInner as Ti;
+
+ match self.inner {
+ // value types are concise enough, we only alias them if they are named
+ Ti::Scalar { .. }
+ | Ti::Vector { .. }
+ | Ti::Matrix { .. }
+ | Ti::Atomic { .. }
+ | Ti::Pointer { .. }
+ | Ti::ValuePointer { .. } => self.name.is_some(),
+ // composite types are better to be aliased, regardless of the name
+ Ti::Struct { .. } | Ti::Array { .. } => true,
+ // handle types may be different, depending on the global var access, so we always inline them
+ Ti::Image { .. } | Ti::Sampler { .. } | Ti::BindingArray { .. } => false,
+ }
+ }
+}
+
+impl crate::Constant {
+ // Returns `true` if we need to emit an alias for this constant.
+ const fn needs_alias(&self) -> bool {
+ match self.inner {
+ crate::ConstantInner::Scalar { .. } => self.name.is_some(),
+ crate::ConstantInner::Composite { .. } => true,
+ }
+ }
+}
+
+enum FunctionOrigin {
+ Handle(Handle<crate::Function>),
+ EntryPoint(proc::EntryPointIndex),
+}
+
+/// A level of detail argument.
+///
+/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
+/// save the clamped level of detail in a temporary variable whose name is based
+/// on the handle of the `ImageLoad` expression. But for other policies, we just
+/// use the expression directly.
+///
+/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+#[derive(Clone, Copy)]
+enum LevelOfDetail {
+ Direct(Handle<crate::Expression>),
+ Restricted(Handle<crate::Expression>),
+}
+
+/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
+///
+/// When this is used in code paths unconcerned with the `Restrict` bounds check
+/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
+/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
+/// be too awkward. If that changes, we can revisit.
+///
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+/// [`ImageStore`]: crate::Statement::ImageStore
+struct TexelAddress {
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ level: Option<LevelOfDetail>,
+}
+
+struct ExpressionContext<'a> {
+ function: &'a crate::Function,
+ origin: FunctionOrigin,
+ info: &'a valid::FunctionInfo,
+ module: &'a crate::Module,
+ pipeline_options: &'a PipelineOptions,
+ policies: index::BoundsCheckPolicies,
+
+ /// A bitset containing the `Expression` handle indexes of expressions used
+ /// as indices in `ReadZeroSkipWrite`-policy accesses. These may need to be
+ /// cached in temporary variables. See `index::find_checked_indexes` for
+ /// details.
+ guarded_indices: BitSet,
+}
+
+impl<'a> ExpressionContext<'a> {
+ fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
+ self.info[handle].ty.inner_with(&self.module.types)
+ }
+
+ /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
+ ///
+ /// Only mipmapped images need to specify a level of detail. Since 1D
+ /// textures cannot have mipmaps, MSL requires that the level argument to
+ /// texture1d queries and accesses must be a constexpr 0. It's easiest
+ /// just to omit the level entirely for 1D textures.
+ fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
+ let image_ty = self.resolve_type(image);
+ if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
+ class.is_mipmapped() && dim != crate::ImageDimension::D1
+ } else {
+ false
+ }
+ }
+
+ fn choose_bounds_check_policy(
+ &self,
+ pointer: Handle<crate::Expression>,
+ ) -> index::BoundsCheckPolicy {
+ self.policies
+ .choose_policy(pointer, &self.module.types, self.info)
+ }
+
+ fn access_needs_check(
+ &self,
+ base: Handle<crate::Expression>,
+ index: index::GuardedIndex,
+ ) -> Option<index::IndexableLength> {
+ index::access_needs_check(base, index, self.module, self.function, self.info)
+ }
+
+ fn get_packed_vec_kind(
+ &self,
+ expr_handle: Handle<crate::Expression>,
+ ) -> Option<crate::ScalarKind> {
+ match self.function.expressions[expr_handle] {
+ crate::Expression::AccessIndex { base, index } => {
+ let ty = match *self.resolve_type(base) {
+ crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
+ ref ty => ty,
+ };
+ match *ty {
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => should_pack_struct_member(members, span, index as usize, self.module),
+ _ => None,
+ }
+ }
+ _ => None,
+ }
+ }
+}
+
+struct StatementContext<'a> {
+ expression: ExpressionContext<'a>,
+ mod_info: &'a valid::ModuleInfo,
+ result_struct: Option<&'a str>,
+}
+
+impl<W: Write> Writer<W> {
+ /// Creates a new `Writer` instance.
+ pub fn new(out: W) -> Self {
+ Writer {
+ out,
+ names: FastHashMap::default(),
+ named_expressions: Default::default(),
+ need_bake_expressions: Default::default(),
+ namer: proc::Namer::default(),
+ #[cfg(test)]
+ put_expression_stack_pointers: Default::default(),
+ #[cfg(test)]
+ put_block_stack_pointers: Default::default(),
+ struct_member_pads: FastHashSet::default(),
+ }
+ }
+
+ /// Finishes writing and returns the output.
+ // See https://github.com/rust-lang/rust-clippy/issues/4979.
+ #[allow(clippy::missing_const_for_fn)]
+ pub fn finish(self) -> W {
+ self.out
+ }
+
+ fn put_call_parameters(
+ &mut self,
+ parameters: impl Iterator<Item = Handle<crate::Expression>>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "(")?;
+ for (i, handle) in parameters.enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_expression(handle, context, true)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_level_of_detail(
+ &mut self,
+ level: LevelOfDetail,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match level {
+ LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
+ LevelOfDetail::Restricted(load) => {
+ write!(self.out, "{}{}", CLAMPED_LOD_LOAD_PREFIX, load.index())?
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ query: &str,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_{}(", query)?;
+ if let Some(level) = level {
+ self.put_level_of_detail(level, context)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_size_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ kind: crate::ScalarKind,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ //Note: MSL only has separate width/height/depth queries,
+ // so compose the result of them.
+ let dim = match *context.resolve_type(image) {
+ crate::TypeInner::Image { dim, .. } => dim,
+ ref other => unreachable!("Unexpected type {:?}", other),
+ };
+ let coordinate_type = kind.to_msl_name();
+ match dim {
+ crate::ImageDimension::D1 => {
+ // Since 1D textures never have mipmaps, MSL requires that the
+ // `level` argument be a constexpr 0. It's simplest for us just
+ // to pass `None` and omit the level entirely.
+ if kind == crate::ScalarKind::Uint {
+ // No need to construct a vector. No cast needed.
+ self.put_image_query(image, "width", None, context)?;
+ } else {
+ // There's no definition for `int` in the `metal` namespace.
+ write!(self.out, "int(")?;
+ self.put_image_query(image, "width", None, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ crate::ImageDimension::D2 => {
+ write!(self.out, "{}::{}2(", NAMESPACE, coordinate_type)?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::D3 => {
+ write!(self.out, "{}::{}3(", NAMESPACE, coordinate_type)?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "depth", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::Cube => {
+ write!(self.out, "{}::{}2(", NAMESPACE, coordinate_type)?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_cast_to_uint_scalar_or_vector(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // coordinates in IR are int, but Metal expects uint
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar { .. } => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Uint, &[])?
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Uint, &[size])?
+ }
+ _ => return Err(Error::Validation),
+ };
+
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_sample_level(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: crate::SampleLevel,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let has_levels = context.image_needs_lod(image);
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {
+ //TODO: do we support Zero on `Sampled` image classes?
+ }
+ _ if !has_levels => {
+ log::warn!("1D image can't be sampled with level {:?}", level);
+ }
+ crate::SampleLevel::Exact(h) => {
+ write!(self.out, ", {}::level(", NAMESPACE)?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Bias(h) => {
+ write!(self.out, ", {}::bias(", NAMESPACE)?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ write!(self.out, ", {}::gradient2d(", NAMESPACE)?;
+ self.put_expression(x, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(y, context, true)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_coordinate_limits(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ write!(self.out, " - 1")?;
+ Ok(())
+ }
+
+ /// General function for writing restricted image indexes.
+ ///
+ /// This is used to produce restricted mip levels, array indices, and sample
+ /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
+ /// [`Restrict`] bounds check policy.
+ ///
+ /// This function writes an expression of the form:
+ ///
+ /// ```ignore
+ ///
+ /// metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
+ ///
+ /// ```
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageStore`]: crate::Statement::ImageStore
+ /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
+ fn put_restricted_scalar_image_index(
+ &mut self,
+ image: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ limit_method: &str,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "{}::min(uint(", NAMESPACE)?;
+ self.put_expression(index, context, true)?;
+ write!(self.out, "), ")?;
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{}() - 1)", limit_method)?;
+ Ok(())
+ }
+
+ fn put_restricted_texel_address(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write the coordinate.
+ write!(self.out, "{}::min(", NAMESPACE)?;
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_coordinate_limits(image, address.level, context)?;
+ write!(self.out, ")")?;
+
+ // Write the array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
+ }
+
+ // Write the sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
+ }
+
+ // The level of detail should be clamped and cached by
+ // `put_cache_restricted_level`, so we don't need to clamp it here.
+ if let Some(level) = address.level {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+
+ Ok(())
+ }
+
+ /// Write an expression that is true if the given image access is in bounds.
+ fn put_image_access_bounds_check(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let mut conjunction = "";
+
+ // First, check the level of detail. Only if that is in bounds can we
+ // use it to find the appropriate bounds for the coordinates.
+ let level = if let Some(level) = address.level {
+ write!(self.out, "uint(")?;
+ self.put_level_of_detail(level, context)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_mip_levels()")?;
+ conjunction = " && ";
+ Some(level)
+ } else {
+ None
+ };
+
+ // Check sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, "uint(")?;
+ self.put_expression(sample, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_samples()")?;
+ conjunction = " && ";
+ }
+
+ // Check array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, "{}uint(", conjunction)?;
+ self.put_expression(array_index, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_array_size()")?;
+ conjunction = " && ";
+ }
+
+ // Finally, check if the coordinates are within bounds.
+ let coord_is_vector = match *context.resolve_type(address.coordinate) {
+ crate::TypeInner::Vector { .. } => true,
+ _ => false,
+ };
+ write!(self.out, "{}", conjunction)?;
+ if coord_is_vector {
+ write!(self.out, "{}::all(", NAMESPACE)?;
+ }
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, " < ")?;
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ if coord_is_vector {
+ write!(self.out, ")")?;
+ }
+
+ Ok(())
+ }
+
+ fn put_image_load(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mut address: TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.policies.image {
+ proc::BoundsCheckPolicy::Restrict => {
+ // Use the cached restricted level of detail, if any. Omit the
+ // level altogether for 1D textures.
+ if address.level.is_some() {
+ address.level = if context.image_needs_lod(image) {
+ Some(LevelOfDetail::Restricted(load))
+ } else {
+ None
+ }
+ }
+
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ self.put_restricted_texel_address(image, &address, context)?;
+ write!(self.out, ")")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "(")?;
+ self.put_image_access_bounds_check(image, &address, context)?;
+ write!(self.out, " ? ")?;
+ self.put_unchecked_image_load(image, &address, context)?;
+ write!(self.out, ": DefaultConstructible())")?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_load(image, &address, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_load(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_expression(sample, context, true)?;
+ }
+ if let Some(level) = address.level {
+ if context.image_needs_lod(image) {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ Ok(())
+ }
+
+ fn put_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ match context.expression.policies.image {
+ proc::BoundsCheckPolicy::Restrict => {
+ // We don't have a restricted level value, because we don't
+ // support writes to mipmapped textures.
+ debug_assert!(address.level.is_none());
+
+ write!(self.out, "{}", level)?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ self.put_restricted_texel_address(image, address, &context.expression)?;
+ writeln!(self.out, ");")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "{}if (", level)?;
+ self.put_image_access_bounds_check(image, address, &context.expression)?;
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_image_store(level.next(), image, address, value, context)?;
+ writeln!(self.out, "{}}}", level)?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_store(level, image, address, value, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ write!(self.out, "{}", level)?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, &context.expression, true)?;
+ }
+ writeln!(self.out, ");")?;
+
+ Ok(())
+ }
+
+ fn put_compose(
+ &mut self,
+ ty: Handle<crate::Type>,
+ components: &[Handle<crate::Expression>],
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.module.types[ty].inner {
+ crate::TypeInner::Scalar { width: 4, kind } if components.len() == 1 => {
+ write!(self.out, "{}", kind.to_msl_name())?;
+ self.put_call_parameters(components.iter().cloned(), context)?;
+ }
+ crate::TypeInner::Vector { size, kind, .. } => {
+ put_numeric_type(&mut self.out, kind, &[size])?;
+ self.put_call_parameters(components.iter().cloned(), context)?;
+ }
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Float, &[rows, columns])?;
+ self.put_call_parameters(components.iter().cloned(), context)?;
+ }
+ crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
+ write!(self.out, "{} {{", &self.names[&NameKey::Type(ty)])?;
+ for (index, &component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // insert padding initialization, if needed
+ if self.struct_member_pads.contains(&(ty, index as u32)) {
+ write!(self.out, "{{}}, ")?;
+ }
+ self.put_expression(component, context, true)?;
+ }
+ write!(self.out, "}}")?;
+ }
+ _ => return Err(Error::UnsupportedCompose(ty)),
+ }
+ Ok(())
+ }
+
+ /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
+ ///
+ /// The 'maximum valid index' is simply one less than the array's length.
+ ///
+ /// This emits an expression of the form `a / b`, so the caller must
+ /// parenthesize its output if it will be applying operators of higher
+ /// precedence.
+ ///
+ /// `handle` must be the handle of a global variable whose final member is a
+ /// dynamically sized array.
+ fn put_dynamic_array_max_index(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let global = &context.module.global_variables[handle];
+ let (offset, array_ty) = match context.module.types[global.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => match members.last() {
+ Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
+ None => return Err(Error::Validation),
+ },
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => (0, global.ty),
+ _ => return Err(Error::Validation),
+ };
+
+ let (size, stride) = match context.module.types[array_ty].inner {
+ crate::TypeInner::Array { base, stride, .. } => (
+ context.module.types[base]
+ .inner
+ .size(&context.module.constants),
+ stride,
+ ),
+ _ => return Err(Error::Validation),
+ };
+
+ // When the stride length is larger than the size, the final element's stride of
+ // bytes would have padding following the value. But the buffer size in
+ // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
+ // enough to hold the actual values' bytes.
+ //
+ // So subtract off the size to get a byte size that falls at the start or within
+ // the final element. Then divide by the stride size, to get one less than the
+ // length, and then add one. This works even if the buffer size does include the
+ // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
+ // if there are zero elements in the array, but the WebGPU `validating shader binding`
+ // rules, together with draw-time validation when `minBindingSize` is zero,
+ // prevent that.
+ write!(
+ self.out,
+ "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}",
+ idx = handle.index(),
+ offset = offset,
+ size = size,
+ stride = stride,
+ )?;
+ Ok(())
+ }
+
+ fn put_atomic_fetch(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_atomic_operation(pointer, "fetch_", key, value, context)
+ }
+
+ fn put_atomic_operation(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key1: &str,
+ key2: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // If the pointer we're passing to the atomic operation needs to be conditional
+ // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
+ // the pointer operand should be unchecked.
+ let policy = context.choose_bounds_check_policy(pointer);
+ let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
+
+ // If requested and successfully put bounds checks, continue the ternary expression.
+ if checked {
+ write!(self.out, " ? ")?;
+ }
+
+ write!(
+ self.out,
+ "{}::atomic_{}{}_explicit({}",
+ NAMESPACE, key1, key2, ATOMIC_REFERENCE
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, context, true)?;
+ write!(self.out, ", {}::memory_order_relaxed)", NAMESPACE)?;
+
+ // Finish the ternary expression.
+ if checked {
+ write!(self.out, " : DefaultConstructible()")?;
+ }
+
+ Ok(())
+ }
+
+ /// Emit code for the arithmetic expression of the dot product.
+ ///
+ fn put_dot_product(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ arg1: Handle<crate::Expression>,
+ size: usize,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write parantheses around the dot product expression to prevent operators
+ // with different precedences from applying earlier.
+ write!(self.out, "(")?;
+
+ // Cycle trough all the components of the vector
+ for index in 0..size {
+ let component = back::COMPONENTS[index];
+ // Write the addition to the previous product
+ // This will print an extra '+' at the beginning but that is fine in msl
+ write!(self.out, " + ")?;
+ // Write the first vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.put_expression(arg, context, true)?;
+ // Access the current component on the first vector
+ write!(self.out, ".{} * ", component)?;
+ // Write the second vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.put_expression(arg1, context, true)?;
+ // Access the current component on the second vector
+ write!(self.out, ".{}", component)?;
+ }
+
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ /// Emit code for the expression `expr_handle`.
+ ///
+ /// The `is_scoped` argument is true if the surrounding operators have the
+ /// precedence of the comma operator, or lower. So, for example:
+ ///
+ /// - Pass `true` for `is_scoped` when writing function arguments, an
+ /// expression statement, an initializer expression, or anything already
+ /// wrapped in parenthesis.
+ ///
+ /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
+ /// almost anything else.
+ fn put_expression(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_expression_stack_pointers
+ .insert(&expr_handle as *const _ as *const ());
+
+ if let Some(name) = self.named_expressions.get(&expr_handle) {
+ write!(self.out, "{}", name)?;
+ return Ok(());
+ }
+
+ let expression = &context.function.expressions[expr_handle];
+ log::trace!("expression {:?} = {:?}", expr_handle, expression);
+ match *expression {
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => {
+ // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
+ // Since `put_bounds_checks` and `put_access_chain` handle an entire
+ // access chain at a time, recursing back through `put_expression` only
+ // for index expressions and the base object, we will never see intermediate
+ // `Access` or `AccessIndex` expressions here.
+ let policy = context.choose_bounds_check_policy(base);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ expr_handle,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_access_chain(expr_handle, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_access_chain(expr_handle, policy, context)?;
+ }
+ }
+ crate::Expression::Constant(handle) => {
+ let coco = ConstantContext {
+ handle,
+ arena: &context.module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, "{}", coco)?;
+ }
+ crate::Expression::Splat { size, value } => {
+ let scalar_kind = match *context.resolve_type(value) {
+ crate::TypeInner::Scalar { kind, .. } => kind,
+ _ => return Err(Error::Validation),
+ };
+ put_numeric_type(&mut self.out, scalar_kind, &[size])?;
+ write!(self.out, "(")?;
+ self.put_expression(value, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.put_wrapped_expression_for_packed_vec3_access(vector, context, false)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
+ }
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ self.put_compose(ty, components, context)?;
+ }
+ crate::Expression::FunctionArgument(index) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointArgument(ep_index, index)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{}", name)?;
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}", name)?;
+ }
+ crate::Expression::LocalVariable(handle) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(fun_handle) => {
+ NameKey::FunctionLocal(fun_handle, handle)
+ }
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointLocal(ep_index, handle)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{}", name)?;
+ }
+ crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ let main_op = match gather {
+ Some(_) => "gather",
+ None => "sample",
+ };
+ let comparison_op = match depth_ref {
+ Some(_) => "_compare",
+ None => "",
+ };
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{}{}(", main_op, comparison_op)?;
+ self.put_expression(sampler, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(coordinate, context, true)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(dref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.put_expression(dref, context, true)?;
+ }
+
+ self.put_image_sample_level(image, level, context)?;
+
+ if let Some(constant) = offset {
+ let coco = ConstantContext {
+ handle: constant,
+ arena: &context.module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, ", {}", coco)?;
+ }
+ match gather {
+ None | Some(crate::SwizzleComponent::X) => {}
+ Some(component) => {
+ let is_cube_map = match *context.resolve_type(image) {
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ ..
+ } => true,
+ _ => false,
+ };
+ // Offset always comes before the gather, except
+ // in cube maps where it's not applicable
+ if offset.is_none() && !is_cube_map {
+ write!(self.out, ", {}::int2(0)", NAMESPACE)?;
+ }
+ let letter = ['x', 'y', 'z', 'w'][component as usize];
+ write!(self.out, ", {}::component::{}", NAMESPACE, letter)?;
+ }
+ }
+ write!(self.out, ")")?;
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample,
+ level: level.map(LevelOfDetail::Direct),
+ };
+ self.put_image_load(expr_handle, image, address, context)?;
+ }
+ //Note: for all the queries, the signed integers are expected,
+ // so a conversion is needed.
+ crate::Expression::ImageQuery { image, query } => match query {
+ crate::ImageQuery::Size { level } => {
+ self.put_image_size_query(
+ image,
+ level.map(LevelOfDetail::Direct),
+ crate::ScalarKind::Sint,
+ context,
+ )?;
+ }
+ crate::ImageQuery::NumLevels => {
+ write!(self.out, "int(")?;
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_mip_levels())")?;
+ }
+ crate::ImageQuery::NumLayers => {
+ write!(self.out, "int(")?;
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_array_size())")?;
+ }
+ crate::ImageQuery::NumSamples => {
+ write!(self.out, "int(")?;
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_samples())")?;
+ }
+ },
+ crate::Expression::Unary { op, expr } => {
+ use crate::{ScalarKind as Sk, UnaryOperator as Uo};
+ let op_str = match op {
+ Uo::Negate => "-",
+ Uo::Not => match context.resolve_type(expr).scalar_kind() {
+ Some(Sk::Sint) | Some(Sk::Uint) => "~",
+ Some(Sk::Bool) => "!",
+ _ => return Err(Error::Validation),
+ },
+ };
+ write!(self.out, "{}", op_str)?;
+ self.put_expression(expr, context, false)?;
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let op_str = crate::back::binary_operation_str(op);
+ let kind = context
+ .resolve_type(left)
+ .scalar_kind()
+ .ok_or(Error::UnsupportedBinaryOp(op))?;
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
+ write!(self.out, "{}::fmod(", NAMESPACE)?;
+ self.put_expression(left, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(right, context, true)?;
+ write!(self.out, ")")?;
+ } else {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+
+ // Cast packed vector if necessary
+ // Packed vector - matrix multiplications are not supported in MSL
+ if op == crate::BinaryOperator::Multiply
+ && matches!(
+ context.resolve_type(right),
+ &crate::TypeInner::Matrix { .. }
+ )
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(left, context, false)?;
+ } else {
+ self.put_expression(left, context, false)?;
+ }
+
+ write!(self.out, " {} ", op_str)?;
+
+ // See comment above
+ if op == crate::BinaryOperator::Multiply
+ && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(right, context, false)?;
+ } else {
+ self.put_expression(right, context, false)?;
+ }
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => match *context.resolve_type(condition) {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ } => {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ self.put_expression(condition, context, false)?;
+ write!(self.out, " ? ")?;
+ self.put_expression(accept, context, false)?;
+ write!(self.out, " : ")?;
+ self.put_expression(reject, context, false)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Bool,
+ ..
+ } => {
+ write!(self.out, "{}::select(", NAMESPACE)?;
+ self.put_expression(reject, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(accept, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(condition, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ crate::Expression::Derivative { axis, expr } => {
+ let op = match axis {
+ crate::DerivativeAxis::X => "dfdx",
+ crate::DerivativeAxis::Y => "dfdy",
+ crate::DerivativeAxis::Width => "fwidth",
+ };
+ write!(self.out, "{}::{}", NAMESPACE, op)?;
+ self.put_call_parameters(iter::once(expr), context)?;
+ }
+ crate::Expression::Relational { fun, argument } => {
+ let op = match fun {
+ crate::RelationalFunction::Any => "any",
+ crate::RelationalFunction::All => "all",
+ crate::RelationalFunction::IsNan => "isnan",
+ crate::RelationalFunction::IsInf => "isinf",
+ crate::RelationalFunction::IsFinite => "isfinite",
+ crate::RelationalFunction::IsNormal => "isnormal",
+ };
+ write!(self.out, "{}::{}", NAMESPACE, op)?;
+ self.put_call_parameters(iter::once(argument), context)?;
+ }
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let scalar_argument = match *context.resolve_type(arg) {
+ crate::TypeInner::Scalar { .. } => true,
+ _ => false,
+ };
+
+ let fun_name = match fun {
+ // comparison
+ Mf::Abs => "abs",
+ Mf::Min => "min",
+ Mf::Max => "max",
+ Mf::Clamp => "clamp",
+ Mf::Saturate => "saturate",
+ // trigonometry
+ Mf::Cos => "cos",
+ Mf::Cosh => "cosh",
+ Mf::Sin => "sin",
+ Mf::Sinh => "sinh",
+ Mf::Tan => "tan",
+ Mf::Tanh => "tanh",
+ Mf::Acos => "acos",
+ Mf::Asin => "asin",
+ Mf::Atan => "atan",
+ Mf::Atan2 => "atan2",
+ Mf::Asinh => "asinh",
+ Mf::Acosh => "acosh",
+ Mf::Atanh => "atanh",
+ Mf::Radians => "",
+ Mf::Degrees => "",
+ // decomposition
+ Mf::Ceil => "ceil",
+ Mf::Floor => "floor",
+ Mf::Round => "rint",
+ Mf::Fract => "fract",
+ Mf::Trunc => "trunc",
+ Mf::Modf => "modf",
+ Mf::Frexp => "frexp",
+ Mf::Ldexp => "ldexp",
+ // exponent
+ Mf::Exp => "exp",
+ Mf::Exp2 => "exp2",
+ Mf::Log => "log",
+ Mf::Log2 => "log2",
+ Mf::Pow => "pow",
+ // geometry
+ Mf::Dot => match *context.resolve_type(arg) {
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Float,
+ ..
+ } => "dot",
+ crate::TypeInner::Vector { size, .. } => {
+ return self.put_dot_product(arg, arg1.unwrap(), size as usize, context)
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => return Err(Error::UnsupportedCall(format!("{:?}", fun))),
+ Mf::Cross => "cross",
+ Mf::Distance => "distance",
+ Mf::Length if scalar_argument => "abs",
+ Mf::Length => "length",
+ Mf::Normalize => "normalize",
+ Mf::FaceForward => "faceforward",
+ Mf::Reflect => "reflect",
+ Mf::Refract => "refract",
+ // computational
+ Mf::Sign => "sign",
+ Mf::Fma => "fma",
+ Mf::Mix => "mix",
+ Mf::Step => "step",
+ Mf::SmoothStep => "smoothstep",
+ Mf::Sqrt => "sqrt",
+ Mf::InverseSqrt => "rsqrt",
+ Mf::Inverse => return Err(Error::UnsupportedCall(format!("{:?}", fun))),
+ Mf::Transpose => "transpose",
+ Mf::Determinant => "determinant",
+ // bits
+ Mf::CountOneBits => "popcount",
+ Mf::ReverseBits => "reverse_bits",
+ Mf::ExtractBits => "extract_bits",
+ Mf::InsertBits => "insert_bits",
+ Mf::FindLsb => "",
+ Mf::FindMsb => "",
+ // data packing
+ Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
+ Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
+ Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
+ Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
+ Mf::Pack2x16float => "",
+ // data unpacking
+ Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
+ Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
+ Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
+ Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
+ Mf::Unpack2x16float => "",
+ };
+
+ if fun == Mf::Distance && scalar_argument {
+ write!(self.out, "{}::abs(", NAMESPACE)?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, " - ")?;
+ self.put_expression(arg1.unwrap(), context, false)?;
+ write!(self.out, ")")?;
+ } else if fun == Mf::FindLsb {
+ write!(self.out, "((({}::ctz(", NAMESPACE)?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ") + 1) % 33) - 1)")?;
+ } else if fun == Mf::FindMsb {
+ write!(self.out, "((({}::clz(", NAMESPACE)?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ") + 1) % 33) - 1)")?
+ } else if fun == Mf::Unpack2x16float {
+ write!(self.out, "float2(as_type<half2>(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Pack2x16float {
+ write!(self.out, "as_type<uint>(half2(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Radians {
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 0.017453292519943295474)")?;
+ } else if fun == Mf::Degrees {
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 57.295779513082322865)")?;
+ } else {
+ write!(self.out, "{}::{}", NAMESPACE, fun_name)?;
+ self.put_call_parameters(
+ iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
+ context,
+ )?;
+ }
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar {
+ kind: src_kind,
+ width: src_width,
+ }
+ | crate::TypeInner::Vector {
+ kind: src_kind,
+ width: src_width,
+ ..
+ } => {
+ let is_bool_cast =
+ kind == crate::ScalarKind::Bool || src_kind == crate::ScalarKind::Bool;
+ let op = match convert {
+ Some(w) if w == src_width || is_bool_cast => "static_cast",
+ Some(8) if kind == crate::ScalarKind::Float => {
+ return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
+ }
+ Some(_) => return Err(Error::Validation),
+ None => "as_type",
+ };
+ write!(self.out, "{}<", op)?;
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, kind, &[size])?
+ }
+ _ => put_numeric_type(&mut self.out, kind, &[])?,
+ };
+ write!(self.out, ">(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(&mut self.out, kind, &[rows, columns])?;
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ // has to be a named expression
+ crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => {
+ unreachable!()
+ }
+ crate::Expression::ArrayLength(expr) => {
+ // Find the global to which the array belongs.
+ let global = match context.function.expressions[expr] {
+ crate::Expression::AccessIndex { base, .. } => {
+ match context.function.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ };
+
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
+ fn put_wrapped_expression_for_packed_vec3_access(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ if let Some(scalar_kind) = context.get_packed_vec_kind(expr_handle) {
+ write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?;
+ self.put_expression(expr_handle, context, is_scoped)?;
+ write!(self.out, ")")?;
+ } else {
+ self.put_expression(expr_handle, context, is_scoped)?;
+ }
+ Ok(())
+ }
+
+ /// Write a `GuardedIndex` as a Metal expression.
+ fn put_index(
+ &mut self,
+ index: index::GuardedIndex,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ match index {
+ index::GuardedIndex::Expression(expr) => {
+ self.put_expression(expr, context, is_scoped)?
+ }
+ index::GuardedIndex::Known(value) => write!(self.out, "{}", value)?,
+ }
+ Ok(())
+ }
+
+ /// Emit an index bounds check condition for `chain`, if required.
+ ///
+ /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
+ /// operating either on a pointer to a value, or on a value directly. If we cannot
+ /// statically determine that all indexing operations in `chain` are within
+ /// bounds, then write a conditional expression to check them dynamically,
+ /// and return true. All accesses in the chain are checked by the generated
+ /// expression.
+ ///
+ /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
+ ///
+ /// The text written is of the form:
+ ///
+ /// ```ignore
+ /// {level}{prefix}uint(i) < 4 && uint(j) < 10
+ /// ```
+ ///
+ /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
+ /// statements, presumably these arguments start an indented `if` statement; for
+ /// [`Load`] expressions, the caller is probably building up a ternary `?:`
+ /// expression. In either case, what is written is not a complete syntactic structure
+ /// in its own right, and the caller will have to finish it off if we return `true`.
+ ///
+ /// If no expression is written, return false.
+ ///
+ /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ /// [`Store`]: crate::Statement::Store
+ /// [`Load`]: crate::Expression::Load
+ #[allow(unused_variables)]
+ fn put_bounds_checks(
+ &mut self,
+ mut chain: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ level: back::Level,
+ prefix: &'static str,
+ ) -> Result<bool, Error> {
+ let mut check_written = false;
+
+ // Iterate over the access chain, handling each expression.
+ loop {
+ // Produce a `GuardedIndex`, so we can shared code between the
+ // `Access` and `AccessIndex` cases.
+ let (base, guarded_index) = match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ (base, Some(index::GuardedIndex::Expression(index)))
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ // Don't try to check indices into structs. Validation already took
+ // care of them, and index::needs_guard doesn't handle that case.
+ let mut base_inner = context.resolve_type(base);
+ if let crate::TypeInner::Pointer { base, .. } = *base_inner {
+ base_inner = &context.module.types[base].inner;
+ }
+ match *base_inner {
+ crate::TypeInner::Struct { .. } => (base, None),
+ _ => (base, Some(index::GuardedIndex::Known(index))),
+ }
+ }
+ _ => break,
+ };
+
+ if let Some(index) = guarded_index {
+ if let Some(length) = context.access_needs_check(base, index) {
+ if check_written {
+ write!(self.out, " && ")?;
+ } else {
+ write!(self.out, "{}{}", level, prefix)?;
+ check_written = true;
+ }
+
+ // Check that the index falls within bounds. Do this with a single
+ // comparison, by casting the index to `uint` first, so that negative
+ // indices become large positive values.
+ write!(self.out, "uint(")?;
+ self.put_index(index, context, true)?;
+ self.out.write_str(") < ")?;
+ match length {
+ index::IndexableLength::Known(value) => write!(self.out, "{}", value)?,
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?
+ }
+ }
+ }
+ }
+
+ chain = base
+ }
+
+ Ok(check_written)
+ }
+
+ /// Write the access chain `chain`.
+ ///
+ /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
+ /// operating either on a pointer to a value, or on a value directly.
+ ///
+ /// Generate bounds checks code only if `policy` is [`Restrict`]. The
+ /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
+ /// that must be handled in the caller.
+ ///
+ /// Handle the entire chain, recursing back into `put_expression` only for index
+ /// expressions and the base expression that originates the pointer or composite value
+ /// being accessed. This allows `put_expression` to assume that any `Access` or
+ /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
+ /// `ReadZeroSkipWrite` checks.
+ ///
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_access_chain(
+ &mut self,
+ chain: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ let mut base_ty = context.resolve_type(base);
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ }
+
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Expression(index),
+ policy,
+ context,
+ )?;
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let base_resolution = &context.info[base].ty;
+ let mut base_ty = base_resolution.inner_with(&context.module.types);
+ let mut base_ty_handle = base_resolution.handle();
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ base_ty_handle = Some(base);
+ }
+
+ // Handle structs and anything else that can use `.x` syntax here, so
+ // `put_subscripted_access_chain` won't have to handle the absurd case of
+ // indexing a struct with an expression.
+ match *base_ty {
+ crate::TypeInner::Struct { .. } => {
+ let base_ty = base_ty_handle.unwrap();
+ self.put_access_chain(base, policy, context)?;
+ let name = &self.names[&NameKey::StructMember(base_ty, index)];
+ write!(self.out, ".{}", name)?;
+ }
+ crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
+ self.put_access_chain(base, policy, context)?;
+ // Prior to Metal v2.1 component access for packed vectors wasn't available
+ // however array indexing is
+ if context.get_packed_vec_kind(base).is_some() {
+ write!(self.out, "[{}]", index)?;
+ } else {
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
+ }
+ }
+ _ => {
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Known(index),
+ policy,
+ context,
+ )?;
+ }
+ }
+ }
+ _ => self.put_expression(chain, context, false)?,
+ }
+
+ Ok(())
+ }
+
+ /// Write a `[]`-style access of `base` by `index`.
+ ///
+ /// If `policy` is [`Restrict`], then generate code as needed to force all index
+ /// values within bounds.
+ ///
+ /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
+ /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
+ /// removed. Our callers often already have this handy.
+ ///
+ /// This only emits `[]` expressions; it doesn't handle struct member accesses or
+ /// referencing vector components by name.
+ ///
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`Array`]: crate::TypeInner::Array
+ /// [`Vector`]: crate::TypeInner::Vector
+ /// [`Pointer`]: crate::TypeInner::Pointer
+ fn put_subscripted_access_chain(
+ &mut self,
+ base: Handle<crate::Expression>,
+ base_ty: &crate::TypeInner,
+ index: index::GuardedIndex,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let accessing_wrapped_array = match *base_ty {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(_),
+ ..
+ } => true,
+ _ => false,
+ };
+
+ self.put_access_chain(base, policy, context)?;
+ if accessing_wrapped_array {
+ write!(self.out, ".{}", WRAPPED_ARRAY_FIELD)?;
+ }
+ write!(self.out, "[")?;
+
+ // Decide whether this index needs to be clamped to fall within range.
+ let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
+ context.access_needs_check(base, index)
+ } else {
+ None
+ };
+ if let Some(limit) = restriction_needed {
+ write!(self.out, "{}::min(unsigned(", NAMESPACE)?;
+ self.put_index(index, context, true)?;
+ write!(self.out, "), ")?;
+ match limit {
+ index::IndexableLength::Known(limit) => {
+ write!(self.out, "{}u", limit - 1)?;
+ }
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ self.put_dynamic_array_max_index(global, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+ } else {
+ self.put_index(index, context, true)?;
+ }
+
+ write!(self.out, "]")?;
+
+ Ok(())
+ }
+
+ fn put_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Since access chains never cross between address spaces, we can just
+ // check the index bounds check policy once at the top.
+ let policy = context.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ pointer,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_unchecked_load(pointer, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_unchecked_load(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let is_atomic = match *context.resolve_type(pointer) {
+ crate::TypeInner::Pointer { base, .. } => match context.module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => true,
+ _ => false,
+ },
+ _ => false,
+ };
+
+ if is_atomic {
+ write!(
+ self.out,
+ "{}::atomic_load_explicit({}",
+ NAMESPACE, ATOMIC_REFERENCE
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", {}::memory_order_relaxed)", NAMESPACE)?;
+ } else {
+ // We don't do any dereferencing with `*` here as pointer arguments to functions
+ // are done by `&` references and not `*` pointers. These do not need to be
+ // dereferenced.
+ self.put_access_chain(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_return_value(
+ &mut self,
+ level: back::Level,
+ expr_handle: Handle<crate::Expression>,
+ result_struct: Option<&str>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match result_struct {
+ Some(struct_name) => {
+ let mut has_point_size = false;
+ let result_ty = context.function.result.as_ref().unwrap().ty;
+ match context.module.types[result_ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let tmp = "_tmp";
+ write!(self.out, "{}const auto {} = ", level, tmp)?;
+ self.put_expression(expr_handle, context, true)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{}return {} {{", level, struct_name)?;
+
+ let mut is_first = true;
+
+ for (index, member) in members.iter().enumerate() {
+ match member.binding {
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) => {
+ has_point_size = true;
+ if !context.pipeline_options.allow_point_size {
+ continue;
+ }
+ }
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::CullDistance)) => {
+ log::warn!("Ignoring CullDistance built-in");
+ continue;
+ }
+ _ => {}
+ }
+
+ let comma = if is_first { "" } else { "," };
+ is_first = false;
+ let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
+ // HACK: we are forcefully deduplicating the expression here
+ // to convert from a wrapped struct to a raw array, e.g.
+ // `float gl_ClipDistance1 [[clip_distance]] [1];`.
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(const_handle),
+ ..
+ } = context.module.types[member.ty].inner
+ {
+ let size = context.module.constants[const_handle]
+ .to_array_length()
+ .unwrap();
+ write!(self.out, "{} {{", comma)?;
+ for j in 0..size {
+ if j != 0 {
+ write!(self.out, ",")?;
+ }
+ write!(
+ self.out,
+ "{}.{}.{}[{}]",
+ tmp, name, WRAPPED_ARRAY_FIELD, j
+ )?;
+ }
+ write!(self.out, "}}")?;
+ } else {
+ write!(self.out, "{} {}.{}", comma, tmp, name)?;
+ }
+ }
+ }
+ _ => {
+ write!(self.out, "{}return {} {{ ", level, struct_name)?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+
+ if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
+ let stage = context.module.entry_points[ep_index as usize].stage;
+ if context.pipeline_options.allow_point_size
+ && stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // point size was injected and comes last
+ write!(self.out, ", 1.0")?;
+ }
+ }
+ write!(self.out, " }}")?;
+ }
+ None => {
+ write!(self.out, "{}return ", level)?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// This function overwrites the contents of `self.need_bake_expressions`
+ fn update_expressions_to_bake(
+ &mut self,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ context: &ExpressionContext,
+ ) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+ for expr in func.expressions.iter() {
+ // Expressions whose reference count is above the
+ // threshold should always be stored in temporaries.
+ let expr_info = &info[expr.0];
+ let min_ref_count = func.expressions[expr.0].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(expr.0);
+ }
+
+ // WGSL's `dot` function works on any `vecN` type, but Metal's only
+ // works on floating-point vectors, so we emit inline code for
+ // integer vector `dot` calls. But that code uses each argument `N`
+ // times, once for each component (see `put_dot_product`), so to
+ // avoid duplicated evaluation, we must bake integer operands.
+ if let (
+ fun_handle,
+ &Expression::Math {
+ fun: crate::MathFunction::Dot,
+ arg,
+ arg1,
+ ..
+ },
+ ) = expr
+ {
+ use crate::TypeInner;
+ // check what kind of product this is depending
+ // on the resolve type of the Dot function itself
+ let inner = context.resolve_type(fun_handle);
+ if let TypeInner::Scalar { kind, .. } = *inner {
+ match kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+ }
+
+ fn start_baking_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ name: &str,
+ ) -> BackendResult {
+ match context.info[handle].ty {
+ TypeResolution::Handle(ty_handle) => {
+ let ty_name = TypeContext {
+ handle: ty_handle,
+ module: context.module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{}", ty_name)?;
+ }
+ TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => {
+ put_numeric_type(&mut self.out, kind, &[])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Vector { size, kind, .. }) => {
+ put_numeric_type(&mut self.out, kind, &[size])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Matrix { columns, rows, .. }) => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Float, &[rows, columns])?;
+ }
+ TypeResolution::Value(ref other) => {
+ log::warn!("Type {:?} isn't a known local", other); //TEMP!
+ return Err(Error::FeatureNotImplemented("weird local type".to_string()));
+ }
+ }
+
+ //TODO: figure out the naming scheme that wouldn't collide with user names.
+ write!(self.out, " {} = ", name)?;
+
+ Ok(())
+ }
+
+ /// Cache a clamped level of detail value, if necessary.
+ ///
+ /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
+ /// properly clamped level of detail value both in the access itself, and
+ /// for fetching the size of the requested MIP level, needed to clamp the
+ /// coordinates. To avoid recomputing this clamped level of detail, we cache
+ /// it in a temporary variable, as part of the [`Emit`] statement covering
+ /// the [`ImageLoad`] expression.
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+ /// [`Emit`]: crate::Statement::Emit
+ fn put_cache_restricted_level(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mip_level: Option<Handle<crate::Expression>>,
+ indent: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Does this image access actually require (or even permit) a
+ // level-of-detail, and does the policy require us to restrict it?
+ let level_of_detail = match mip_level {
+ Some(level) => level,
+ None => return Ok(()),
+ };
+
+ if context.expression.policies.image != index::BoundsCheckPolicy::Restrict
+ || !context.expression.image_needs_lod(image)
+ {
+ return Ok(());
+ }
+
+ write!(
+ self.out,
+ "{}uint {}{} = ",
+ indent,
+ CLAMPED_LOD_LOAD_PREFIX,
+ load.index(),
+ )?;
+ self.put_restricted_scalar_image_index(
+ image,
+ level_of_detail,
+ "get_num_mip_levels",
+ &context.expression,
+ )?;
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ fn put_block(
+ &mut self,
+ level: back::Level,
+ statements: &[crate::Statement],
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_block_stack_pointers
+ .insert(&level as *const _ as *const ());
+
+ for statement in statements {
+ log::trace!("statement[{}] {:?}", level.0, statement);
+ match *statement {
+ crate::Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ // `ImageLoad` expressions covered by the `Restrict` bounds check policy
+ // may need to cache a clamped version of their level-of-detail argument.
+ if let crate::Expression::ImageLoad {
+ image,
+ level: mip_level,
+ ..
+ } = context.expression.function.expressions[handle]
+ {
+ self.put_cache_restricted_level(
+ handle, image, mip_level, level, context,
+ )?;
+ }
+
+ let info = &context.expression.info[handle];
+ let ptr_class = info
+ .ty
+ .inner_with(&context.expression.module.types)
+ .pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ None // don't bake pointer expressions (just yet)
+ } else if let Some(name) =
+ context.expression.function.named_expressions.get(&handle)
+ {
+ // The `crate::Function::named_expressions` table holds
+ // expressions that should be saved in temporaries once they
+ // are `Emit`ted. We only add them to `self.named_expressions`
+ // when we reach the `Emit` that covers them, so that we don't
+ // try to use their names before we've actually initialized
+ // the temporary that holds them.
+ //
+ // Don't assume the names in `named_expressions` are unique,
+ // or even valid. Use the `Namer`.
+ Some(self.namer.call(name))
+ } else if info.ref_count == 0 {
+ Some(self.namer.call(""))
+ } else {
+ // If this expression is an index that we're going to first compare
+ // against a limit, and then actually use as an index, then we may
+ // want to cache it in a temporary, to avoid evaluating it twice.
+ let bake =
+ if context.expression.guarded_indices.contains(handle.index()) {
+ true
+ } else {
+ self.need_bake_expressions.contains(&handle)
+ };
+
+ if bake {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{}", level)?;
+ self.start_baking_expression(handle, &context.expression, &name)?;
+ self.put_expression(handle, &context.expression, true)?;
+ self.named_expressions.insert(handle, name);
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ crate::Statement::Block(ref block) => {
+ if !block.is_empty() {
+ writeln!(self.out, "{}{{", level)?;
+ self.put_block(level.next(), block, context)?;
+ writeln!(self.out, "{}}}", level)?;
+ }
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{}if (", level)?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ self.put_block(level.next(), accept, context)?;
+ if !reject.is_empty() {
+ writeln!(self.out, "{}}} else {{", level)?;
+ self.put_block(level.next(), reject, context)?;
+ }
+ writeln!(self.out, "{}}}", level)?;
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ write!(self.out, "{}switch(", level)?;
+ self.put_expression(selector, &context.expression, true)?;
+ let type_postfix = match *context.expression.resolve_type(selector) {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ } => "u",
+ _ => "",
+ };
+ writeln!(self.out, ") {{")?;
+ let lcase = level.next();
+ for case in cases.iter() {
+ match case.value {
+ crate::SwitchValue::Integer(value) => {
+ writeln!(self.out, "{}case {}{}: {{", lcase, value, type_postfix)?;
+ }
+ crate::SwitchValue::Default => {
+ writeln!(self.out, "{}default: {{", lcase)?;
+ }
+ }
+ self.put_block(lcase.next(), &case.body, context)?;
+ if !case.fall_through
+ && case.body.last().map_or(true, |s| !s.is_terminator())
+ {
+ writeln!(self.out, "{}break;", lcase.next())?;
+ }
+ writeln!(self.out, "{}}}", lcase)?;
+ }
+ writeln!(self.out, "{}}}", level)?;
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{}bool {} = true;", level, gate_name)?;
+ writeln!(self.out, "{}while(true) {{", level)?;
+ let lif = level.next();
+ let lcontinuing = lif.next();
+ writeln!(self.out, "{}if (!{}) {{", lif, gate_name)?;
+ self.put_block(lcontinuing, continuing, context)?;
+ if let Some(condition) = break_if {
+ write!(self.out, "{}if (", lcontinuing)?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", lcontinuing.next())?;
+ writeln!(self.out, "{}}}", lcontinuing)?;
+ }
+ writeln!(self.out, "{}}}", lif)?;
+ writeln!(self.out, "{}{} = false;", lif, gate_name)?;
+ } else {
+ writeln!(self.out, "{}while(true) {{", level)?;
+ }
+ self.put_block(level.next(), body, context)?;
+ writeln!(self.out, "{}}}", level)?;
+ }
+ crate::Statement::Break => {
+ writeln!(self.out, "{}break;", level)?;
+ }
+ crate::Statement::Continue => {
+ writeln!(self.out, "{}continue;", level)?;
+ }
+ crate::Statement::Return {
+ value: Some(expr_handle),
+ } => {
+ self.put_return_value(
+ level,
+ expr_handle,
+ context.result_struct,
+ &context.expression,
+ )?;
+ }
+ crate::Statement::Return { value: None } => {
+ writeln!(self.out, "{}return;", level)?;
+ }
+ crate::Statement::Kill => {
+ writeln!(self.out, "{}{}::discard_fragment();", level, NAMESPACE)?;
+ }
+ crate::Statement::Barrier(flags) => {
+ //Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
+ // so we try to avoid it here.
+ if flags.is_empty() {
+ writeln!(
+ self.out,
+ "{}{}::threadgroup_barrier({}::mem_flags::mem_none);",
+ level, NAMESPACE, NAMESPACE,
+ )?;
+ }
+ if flags.contains(crate::Barrier::STORAGE) {
+ writeln!(
+ self.out,
+ "{}{}::threadgroup_barrier({}::mem_flags::mem_device);",
+ level, NAMESPACE, NAMESPACE,
+ )?;
+ }
+ if flags.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(
+ self.out,
+ "{}{}::threadgroup_barrier({}::mem_flags::mem_threadgroup);",
+ level, NAMESPACE, NAMESPACE,
+ )?;
+ }
+ }
+ crate::Statement::Store { pointer, value } => {
+ self.put_store(pointer, value, level, context)?
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample: None,
+ level: None,
+ };
+ self.put_image_store(level, image, &address, value, context)?
+ }
+ crate::Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{}", level)?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ self.start_baking_expression(expr, &context.expression, &name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let fun_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{}(", fun_name)?;
+ // first, write down the actual arguments
+ for (i, &handle) in arguments.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_expression(handle, &context.expression, true)?;
+ }
+ // follow-up with any global resources used
+ let mut separate = !arguments.is_empty();
+ let fun_info = &context.mod_info[function];
+ let mut supports_array_length = false;
+ for (handle, var) in context.expression.module.global_variables.iter() {
+ if fun_info[handle].is_empty() {
+ continue;
+ }
+ if var.space.needs_pass_through() {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ if separate {
+ write!(self.out, ", ")?;
+ } else {
+ separate = true;
+ }
+ write!(self.out, "{}", name)?;
+ }
+ supports_array_length |=
+ needs_array_length(var.ty, &context.expression.module.types);
+ }
+ if supports_array_length {
+ if separate {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "_buffer_sizes")?;
+ }
+
+ // done
+ writeln!(self.out, ");")?;
+ }
+ crate::Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{}", level)?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_baking_expression(result, &context.expression, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+ match *fun {
+ crate::AtomicFunction::Add => {
+ self.put_atomic_fetch(pointer, "add", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Subtract => {
+ self.put_atomic_fetch(pointer, "sub", value, &context.expression)?;
+ }
+ crate::AtomicFunction::And => {
+ self.put_atomic_fetch(pointer, "and", value, &context.expression)?;
+ }
+ crate::AtomicFunction::InclusiveOr => {
+ self.put_atomic_fetch(pointer, "or", value, &context.expression)?;
+ }
+ crate::AtomicFunction::ExclusiveOr => {
+ self.put_atomic_fetch(pointer, "xor", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Min => {
+ self.put_atomic_fetch(pointer, "min", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Max => {
+ self.put_atomic_fetch(pointer, "max", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Exchange { compare: None } => {
+ self.put_atomic_operation(
+ pointer,
+ "exchange",
+ "",
+ value,
+ &context.expression,
+ )?;
+ }
+ crate::AtomicFunction::Exchange { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "atomic CompareExchange".to_string(),
+ ));
+ }
+ }
+ // done
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+
+ // un-emit expressions
+ //TODO: take care of loop/continuing?
+ for statement in statements {
+ if let crate::Statement::Emit(ref range) = *statement {
+ for handle in range.clone() {
+ self.named_expressions.remove(&handle);
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn put_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let policy = context.expression.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
+ {
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
+ writeln!(self.out, "{}}}", level)?;
+ } else {
+ self.put_unchecked_store(pointer, value, policy, level, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let pointer_inner = context.expression.resolve_type(pointer);
+ let (array_size, is_atomic) = match *pointer_inner {
+ crate::TypeInner::Pointer { base, .. } => {
+ match context.expression.module.types[base].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(ch),
+ ..
+ } => (Some(ch), false),
+ crate::TypeInner::Atomic { .. } => (None, true),
+ _ => (None, false),
+ }
+ }
+ _ => (None, false),
+ };
+
+ // we can't assign fixed-size arrays
+ if let Some(const_handle) = array_size {
+ let size = context.expression.module.constants[const_handle]
+ .to_array_length()
+ .unwrap();
+ write!(self.out, "{}for(int _i=0; _i<{}; ++_i) ", level, size)?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, ".{}[_i] = ", WRAPPED_ARRAY_FIELD)?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ".{}[_i];", WRAPPED_ARRAY_FIELD)?;
+ } else if is_atomic {
+ write!(
+ self.out,
+ "{}{}::atomic_store_explicit({}",
+ level, NAMESPACE, ATOMIC_REFERENCE
+ )?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ", {}::memory_order_relaxed);", NAMESPACE)?;
+ } else {
+ write!(self.out, "{}", level)?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, " = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ pub fn write(
+ &mut self,
+ module: &crate::Module,
+ info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ self.names.clear();
+ self.namer
+ .reset(module, super::keywords::RESERVED, &[], &mut self.names);
+ self.struct_member_pads.clear();
+
+ writeln!(
+ self.out,
+ "// language: metal{}.{}",
+ options.lang_version.0, options.lang_version.1
+ )?;
+ writeln!(self.out, "#include <metal_stdlib>")?;
+ writeln!(self.out, "#include <simd/simd.h>")?;
+ writeln!(self.out)?;
+ // Work around Metal bug where `uint` is not available by default
+ writeln!(self.out, "using {}::uint;", NAMESPACE)?;
+ writeln!(self.out)?;
+
+ if options
+ .bounds_check_policies
+ .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
+ {
+ self.put_default_constructible()?;
+ }
+
+ {
+ let mut indices = vec![];
+ for (handle, var) in module.global_variables.iter() {
+ if needs_array_length(var.ty, &module.types) {
+ let idx = handle.index();
+ indices.push(idx);
+ }
+ }
+
+ if !indices.is_empty() {
+ writeln!(self.out, "struct _mslBufferSizes {{")?;
+
+ for idx in indices {
+ writeln!(self.out, "{}uint size{};", back::INDENT, idx)?;
+ }
+
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+ }
+ };
+
+ self.write_scalar_constants(module)?;
+ self.write_type_defs(module)?;
+ self.write_composite_constants(module)?;
+ self.write_functions(module, info, options, pipeline_options)
+ }
+
+ /// Write the definition for the `DefaultConstructible` class.
+ ///
+ /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
+ /// produce 'zero' values for any type, including structs, arrays, and so
+ /// on. We could do this by emitting default constructor applications, but
+ /// that would entail printing the name of the type, which is more trouble
+ /// than you'd think. Instead, we just construct this magic C++14 class that
+ /// can be converted to any type that can be default constructed, using
+ /// template parameter inference to detect which type is needed, so we don't
+ /// have to figure out the name.
+ ///
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_default_constructible(&mut self) -> BackendResult {
+ writeln!(self.out, "struct DefaultConstructible {{")?;
+ writeln!(self.out, " template<typename T>")?;
+ writeln!(self.out, " operator T() && {{")?;
+ writeln!(self.out, " return T {{}};")?;
+ writeln!(self.out, " }}")?;
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, ty) in module.types.iter() {
+ if !ty.needs_alias() {
+ continue;
+ }
+ let name = &self.names[&NameKey::Type(handle)];
+ match ty.inner {
+ // Naga IR can pass around arrays by value, but Metal, following
+ // C++, performs an array-to-pointer conversion (C++ [conv.array])
+ // on expressions of array type, so assigning the array by value
+ // isn't possible. However, Metal *does* assign structs by
+ // value. So in our Metal output, we wrap all array types in
+ // synthetic struct types:
+ //
+ // struct type1 {
+ // float inner[10]
+ // };
+ //
+ // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
+ // any expression that actually wants access to the array.
+ crate::TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ let base_name = TypeContext {
+ handle: base,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let coco = ConstantContext {
+ handle: const_handle,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+
+ writeln!(self.out, "struct {} {{", name)?;
+ writeln!(
+ self.out,
+ "{}{} {}[{}];",
+ back::INDENT,
+ base_name,
+ WRAPPED_ARRAY_FIELD,
+ coco
+ )?;
+ writeln!(self.out, "}};")?;
+ }
+ crate::ArraySize::Dynamic => {
+ writeln!(self.out, "typedef {} {}[1];", base_name, name)?;
+ }
+ }
+ }
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => {
+ writeln!(self.out, "struct {} {{", name)?;
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ // quick and dirty way to figure out if we need this...
+ if member.binding.is_none() && member.offset > last_offset {
+ self.struct_member_pads.insert((handle, index as u32));
+ let pad = member.offset - last_offset;
+ writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset + ty_inner.size(&module.constants);
+
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+
+ // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
+ match should_pack_struct_member(members, span, index, module) {
+ Some(kind) => {
+ writeln!(
+ self.out,
+ "{}{}::packed_{}3 {};",
+ back::INDENT,
+ NAMESPACE,
+ kind.to_msl_name(),
+ member_name
+ )?;
+ }
+ None => {
+ let base_name = TypeContext {
+ handle: member.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ writeln!(
+ self.out,
+ "{}{} {};",
+ back::INDENT,
+ base_name,
+ member_name
+ )?;
+
+ // for 3-component vectors, add one component
+ if let crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ kind: _,
+ width,
+ } = *ty_inner
+ {
+ last_offset += width as u32;
+ }
+ }
+ }
+ }
+ writeln!(self.out, "}};")?;
+ }
+ _ => {
+ let ty_name = TypeContext {
+ handle,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ writeln!(self.out, "typedef {} {};", ty_name, name)?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn write_scalar_constants(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, constant) in module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } if constant.name.is_some() => {
+ debug_assert!(constant.needs_alias());
+ write!(self.out, "constexpr constant ")?;
+ match *value {
+ crate::ScalarValue::Sint(_) => {
+ write!(self.out, "int")?;
+ }
+ crate::ScalarValue::Uint(_) => {
+ write!(self.out, "unsigned")?;
+ }
+ crate::ScalarValue::Float(_) => {
+ write!(self.out, "float")?;
+ }
+ crate::ScalarValue::Bool(_) => {
+ write!(self.out, "bool")?;
+ }
+ }
+ let name = &self.names[&NameKey::Constant(handle)];
+ let coco = ConstantContext {
+ handle,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: true,
+ };
+ writeln!(self.out, " {} = {};", name, coco)?;
+ }
+ _ => {}
+ }
+ }
+ Ok(())
+ }
+
+ fn write_composite_constants(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, constant) in module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Scalar { .. } => {}
+ crate::ConstantInner::Composite { ty, ref components } => {
+ debug_assert!(constant.needs_alias());
+ let name = &self.names[&NameKey::Constant(handle)];
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "constant {} {} = {{", ty_name, name,)?;
+ for (i, &sub_handle) in components.iter().enumerate() {
+ // insert padding initialization, if needed
+ if self.struct_member_pads.contains(&(ty, i as u32)) {
+ write!(self.out, ", {{}}")?;
+ }
+ let separator = if i != 0 { ", " } else { "" };
+ let coco = ConstantContext {
+ handle: sub_handle,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, "{}{}", separator, coco)?;
+ }
+ writeln!(self.out, "}};")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn put_inline_sampler_properties(
+ &mut self,
+ level: back::Level,
+ sampler: &sm::InlineSampler,
+ ) -> BackendResult {
+ for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
+ writeln!(
+ self.out,
+ "{}{}::{}_address::{},",
+ level,
+ NAMESPACE,
+ letter,
+ address.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::mag_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.mag_filter.as_str(),
+ )?;
+ writeln!(
+ self.out,
+ "{}{}::min_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.min_filter.as_str(),
+ )?;
+ if let Some(filter) = sampler.mip_filter {
+ writeln!(
+ self.out,
+ "{}{}::mip_filter::{},",
+ level,
+ NAMESPACE,
+ filter.as_str(),
+ )?;
+ }
+ // avoid setting it on platforms that don't support it
+ if sampler.border_color != sm::BorderColor::TransparentBlack {
+ writeln!(
+ self.out,
+ "{}{}::border_color::{},",
+ level,
+ NAMESPACE,
+ sampler.border_color.as_str(),
+ )?;
+ }
+ //TODO: I'm not able to feed this in a way that MSL likes:
+ //>error: use of undeclared identifier 'lod_clamp'
+ //>error: no member named 'max_anisotropy' in namespace 'metal'
+ if false {
+ if let Some(ref lod) = sampler.lod_clamp {
+ writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
+ }
+ if let Some(aniso) = sampler.max_anisotropy {
+ writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
+ }
+ }
+ if sampler.compare_func != sm::CompareFunc::Never {
+ writeln!(
+ self.out,
+ "{}{}::compare_func::{},",
+ level,
+ NAMESPACE,
+ sampler.compare_func.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::coord::{}",
+ level,
+ NAMESPACE,
+ sampler.coord.as_str()
+ )?;
+ Ok(())
+ }
+
+ // Returns the array of mapped entry point names.
+ fn write_functions(
+ &mut self,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ let mut pass_through_globals = Vec::new();
+ for (fun_handle, fun) in module.functions.iter() {
+ log::trace!(
+ "function {:?}, handle {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ fun_handle
+ );
+
+ let fun_info = &mod_info[fun_handle];
+ pass_through_globals.clear();
+ let mut supports_array_length = false;
+ for (handle, var) in module.global_variables.iter() {
+ if !fun_info[handle].is_empty() {
+ if var.space.needs_pass_through() {
+ pass_through_globals.push(handle);
+ }
+ supports_array_length |= needs_array_length(var.ty, &module.types);
+ }
+ }
+
+ writeln!(self.out)?;
+ let fun_name = &self.names[&NameKey::Function(fun_handle)];
+ match fun.result {
+ Some(ref result) => {
+ let ty_name = TypeContext {
+ handle: result.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{}", ty_name)?;
+ }
+ None => {
+ write!(self.out, "void")?;
+ }
+ }
+ writeln!(self.out, " {}(", fun_name)?;
+
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
+ let param_type_name = TypeContext {
+ handle: arg.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let separator = separate(
+ !pass_through_globals.is_empty()
+ || index + 1 != fun.arguments.len()
+ || supports_array_length,
+ );
+ writeln!(
+ self.out,
+ "{}{} {}{}",
+ back::INDENT,
+ param_type_name,
+ name,
+ separator
+ )?;
+ }
+ for (index, &handle) in pass_through_globals.iter().enumerate() {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage: fun_info[handle],
+ binding: None,
+ reference: true,
+ };
+ let separator =
+ separate(index + 1 != pass_through_globals.len() || supports_array_length);
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ writeln!(self.out, "{}", separator)?;
+ }
+
+ if supports_array_length {
+ writeln!(
+ self.out,
+ "{}constant _mslBufferSizes& _buffer_sizes",
+ back::INDENT
+ )?;
+ }
+
+ writeln!(self.out, ") {{")?;
+
+ for (local_handle, local) in fun.local_variables.iter() {
+ let ty_name = TypeContext {
+ handle: local.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?;
+ match local.init {
+ Some(value) => {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, " = {}", coco)?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::Handle(fun_handle),
+ info: fun_info,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ pipeline_options,
+ },
+ mod_info,
+ result_struct: None,
+ };
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ }
+
+ let mut info = TranslationInfo {
+ entry_point_names: Vec::with_capacity(module.entry_points.len()),
+ };
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let fun = &ep.function;
+ let fun_info = mod_info.get_entry_point(ep_index);
+ let mut ep_error = None;
+
+ log::trace!(
+ "entry point {:?}, index {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ ep_index
+ );
+
+ // Is any global variable used by this entry point dynamically sized?
+ let supports_array_length = module
+ .global_variables
+ .iter()
+ .filter(|&(handle, _)| !fun_info[handle].is_empty())
+ .any(|(_, var)| needs_array_length(var.ty, &module.types));
+
+ // skip this entry point if any global bindings are missing,
+ // or their types are incompatible.
+ if !options.fake_missing_bindings {
+ for (var_handle, var) in module.global_variables.iter() {
+ if fun_info[var_handle].is_empty() {
+ continue;
+ }
+ if let Some(ref br) = var.binding {
+ let good = match options.per_stage_map[ep.stage].resources.get(br) {
+ Some(target) => {
+ let binding_ty = match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => {
+ &module.types[base].inner
+ }
+ ref ty => ty,
+ };
+ match *binding_ty {
+ crate::TypeInner::Image { .. } => target.texture.is_some(),
+ crate::TypeInner::Sampler { .. } => target.sampler.is_some(),
+ _ => target.buffer.is_some(),
+ }
+ }
+ None => false,
+ };
+ if !good {
+ ep_error = Some(super::EntryPointError::MissingBinding(br.clone()));
+ break;
+ }
+ }
+ if var.space == crate::AddressSpace::PushConstant {
+ if let Err(e) = options.resolve_push_constants(ep.stage) {
+ ep_error = Some(e);
+ break;
+ }
+ }
+ }
+ if supports_array_length {
+ if let Err(err) = options.resolve_sizes_buffer(ep.stage) {
+ ep_error = Some(err);
+ }
+ }
+ }
+
+ if let Some(err) = ep_error {
+ info.entry_point_names.push(Err(err));
+ continue;
+ }
+ let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
+ info.entry_point_names.push(Ok(fun_name.clone()));
+
+ writeln!(self.out)?;
+
+ let (em_str, in_mode, out_mode) = match ep.stage {
+ crate::ShaderStage::Vertex => (
+ "vertex",
+ LocationMode::VertexInput,
+ LocationMode::VertexOutput,
+ ),
+ crate::ShaderStage::Fragment { .. } => (
+ "fragment",
+ LocationMode::FragmentInput,
+ LocationMode::FragmentOutput,
+ ),
+ crate::ShaderStage::Compute { .. } => {
+ ("kernel", LocationMode::Uniform, LocationMode::Uniform)
+ }
+ };
+
+ // List all the Naga `EntryPoint`'s `Function`'s arguments,
+ // flattening structs into their members. In Metal, we will pass
+ // each of these values to the entry point as a separate argument—
+ // except for the varyings, handled next.
+ let mut flattened_arguments = Vec::new();
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (member_index, member) in members.iter().enumerate() {
+ let member_index = member_index as u32;
+ flattened_arguments.push((
+ NameKey::StructMember(arg.ty, member_index),
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ }
+ }
+ _ => flattened_arguments.push((
+ NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
+ arg.ty,
+ arg.binding.as_ref(),
+ )),
+ }
+ }
+
+ // Identify the varyings among the argument values, and emit a
+ // struct type named `<fun>Input` to hold them.
+ let stage_in_name = format!("{}Input", fun_name);
+ let varyings_member_name = self.namer.call("varyings");
+ let mut has_varyings = false;
+ if !flattened_arguments.is_empty() {
+ writeln!(self.out, "struct {} {{", stage_in_name)?;
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(ref binding @ &crate::Binding::Location { .. }) => binding,
+ _ => continue,
+ };
+ has_varyings = true;
+ let name = &self.names[name_key];
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ }
+
+ // Define a struct type named for the return value, if any, named
+ // `<fun>Output`.
+ let stage_out_name = format!("{}Output", fun_name);
+ let result_member_name = self.namer.call("member");
+ let result_type_name = match fun.result {
+ Some(ref result) => {
+ let mut result_members = Vec::new();
+ if let crate::TypeInner::Struct { ref members, .. } =
+ module.types[result.ty].inner
+ {
+ for (member_index, member) in members.iter().enumerate() {
+ result_members.push((
+ &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ }
+ } else {
+ result_members.push((
+ &result_member_name,
+ result.ty,
+ result.binding.as_ref(),
+ ));
+ }
+
+ writeln!(self.out, "struct {} {{", stage_out_name)?;
+ let mut has_point_size = false;
+ for (name, ty, binding) in result_members {
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ let binding = binding.ok_or(Error::Validation)?;
+
+ match *binding {
+ // Point size is only supported in VS of pipelines with
+ // point primitive topology.
+ crate::Binding::BuiltIn(crate::BuiltIn::PointSize) => {
+ has_point_size = true;
+ if !pipeline_options.allow_point_size {
+ continue;
+ }
+ }
+ // Cull Distance is not supported in Metal.
+ // But we can't return UnsupportedBuiltIn error to user.
+ // Because otherwise we can't generate msl shader from any glslang SPIR-V shaders.
+ // glslang generates gl_PerVertex struct with gl_CullDistance builtin inside by default.
+ crate::Binding::BuiltIn(crate::BuiltIn::CullDistance) => {
+ log::warn!("Ignoring CullDistance BuiltIn");
+ continue;
+ }
+ _ => {}
+ }
+
+ let array_len = match module.types[ty].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(handle),
+ ..
+ } => module.constants[handle].to_array_length(),
+ _ => None,
+ };
+ let resolved = options.resolve_local_binding(binding, out_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ if let Some(array_len) = array_len {
+ write!(self.out, " [{}]", array_len)?;
+ }
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+
+ if pipeline_options.allow_point_size
+ && ep.stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // inject the point size output last
+ writeln!(
+ self.out,
+ "{}float _point_size [[point_size]];",
+ back::INDENT
+ )?;
+ }
+ writeln!(self.out, "}};")?;
+ &stage_out_name
+ }
+ None => "void",
+ };
+
+ // Write the entry point function's name, and begin its argument list.
+ writeln!(self.out, "{} {} {}(", em_str, result_type_name, fun_name)?;
+ let mut is_first_argument = true;
+
+ // If we have produced a struct holding the `EntryPoint`'s
+ // `Function`'s arguments' varyings, pass that struct first.
+ if has_varyings {
+ writeln!(
+ self.out,
+ " {} {} [[stage_in]]",
+ stage_in_name, varyings_member_name
+ )?;
+ is_first_argument = false;
+ }
+
+ // Then pass the remaining arguments not included in the varyings
+ // struct.
+ //
+ // Since `Namer.reset` wasn't expecting struct members to be
+ // suddenly injected into the normal namespace like this,
+ // `self.names` doesn't keep them distinct from other variables.
+ // Generate fresh names for these arguments, and remember the
+ // mapping.
+ let mut flattened_member_names = FastHashMap::default();
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(ref binding @ &crate::Binding::BuiltIn { .. }) => binding,
+ _ => continue,
+ };
+ let name = if let NameKey::StructMember(ty, index) = *name_key {
+ // We should always insert a fresh entry here, but use
+ // `or_insert` to get a reference to the `String` we just
+ // inserted.
+ flattened_member_names
+ .entry(NameKey::StructMember(ty, index))
+ .or_insert_with(|| self.namer.call(&self.names[name_key]))
+ } else {
+ &self.names[name_key]
+ };
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{} {} {}", separator, ty_name, name)?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ // Those global variables used by this entry point and its callees
+ // get passed as arguments. `Private` globals are an exception, they
+ // don't outlive this invocation, so we declare them below as locals
+ // within the entry point.
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() || var.space == crate::AddressSpace::Private {
+ continue;
+ }
+ // the resolves have already been checked for `!fake_missing_bindings` case
+ let resolved = match var.space {
+ crate::AddressSpace::PushConstant => {
+ options.resolve_push_constants(ep.stage).ok()
+ }
+ crate::AddressSpace::WorkGroup => None,
+ crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => {
+ return Err(Error::UnsupportedAddressSpace(var.space))
+ }
+ _ => options
+ .resolve_resource_binding(ep.stage, var.binding.as_ref().unwrap())
+ .ok(),
+ };
+ if let Some(ref resolved) = resolved {
+ // Inline samplers are be defined in the EP body
+ if resolved.as_inline_sampler(options).is_some() {
+ continue;
+ }
+ }
+
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: resolved.as_ref(),
+ reference: true,
+ };
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{} ", separator)?;
+ tyvar.try_fmt(&mut self.out)?;
+ if let Some(resolved) = resolved {
+ resolved.try_fmt(&mut self.out)?;
+ }
+ if let Some(value) = var.init {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, " = {}", coco)?;
+ }
+ writeln!(self.out)?;
+ }
+
+ // If this entry uses any variable-length arrays, their sizes are
+ // passed as a final struct-typed argument.
+ if supports_array_length {
+ // this is checked earlier
+ let resolved = options.resolve_sizes_buffer(ep.stage).unwrap();
+ let separator = if module.global_variables.is_empty() {
+ ' '
+ } else {
+ ','
+ };
+ write!(
+ self.out,
+ "{} constant _mslBufferSizes& _buffer_sizes",
+ separator,
+ )?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ // end of the entry point argument list
+ writeln!(self.out, ") {{")?;
+
+ // Metal doesn't support private mutable variables outside of functions,
+ // so we put them here, just like the locals.
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() {
+ continue;
+ }
+ if var.space == crate::AddressSpace::Private {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: None,
+ reference: false,
+ };
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ match var.init {
+ Some(value) => {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ writeln!(self.out, " = {};", coco)?;
+ }
+ None => {
+ writeln!(self.out, " = {{}};")?;
+ }
+ };
+ } else if let Some(ref binding) = var.binding {
+ // write an inline sampler
+ let resolved = options.resolve_resource_binding(ep.stage, binding).unwrap();
+ if let Some(sampler) = resolved.as_inline_sampler(options) {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ writeln!(
+ self.out,
+ "{}constexpr {}::sampler {}(",
+ back::INDENT,
+ NAMESPACE,
+ name
+ )?;
+ self.put_inline_sampler_properties(back::Level(2), sampler)?;
+ writeln!(self.out, "{});", back::INDENT)?;
+ }
+ }
+ }
+
+ // Now take the arguments that we gathered into structs, and the
+ // structs that we flattened into arguments, and emit local
+ // variables with initializers that put everything back the way the
+ // body code expects.
+ //
+ // If we had to generate fresh names for struct members passed as
+ // arguments, be sure to use those names when rebuilding the struct.
+ //
+ // "Each day, I change some zeros to ones, and some ones to zeros.
+ // The rest, I leave alone."
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ let arg_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(arg.ty)];
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ back::INDENT,
+ struct_name,
+ arg_name
+ )?;
+ for (member_index, member) in members.iter().enumerate() {
+ let key = NameKey::StructMember(arg.ty, member_index as u32);
+ // If it's not in the varying struct, then we should
+ // have passed it as its own argument and assigned
+ // it a new name.
+ let name = match member.binding {
+ Some(crate::Binding::BuiltIn { .. }) => {
+ &flattened_member_names[&key]
+ }
+ _ => &self.names[&key],
+ };
+ if member_index != 0 {
+ write!(self.out, ", ")?;
+ }
+ if let Some(crate::Binding::Location { .. }) = member.binding {
+ write!(self.out, "{}.", varyings_member_name)?;
+ }
+ write!(self.out, "{}", name)?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ if let Some(crate::Binding::Location { .. }) = arg.binding {
+ writeln!(
+ self.out,
+ "{}const auto {} = {}.{};",
+ back::INDENT,
+ arg_name,
+ varyings_member_name,
+ arg_name
+ )?;
+ }
+ }
+ }
+ }
+
+ // Finally, declare all the local variables that we need
+ //TODO: we can postpone this till the relevant expressions are emitted
+ for (local_handle, local) in fun.local_variables.iter() {
+ let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)];
+ let ty_name = TypeContext {
+ handle: local.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ match local.init {
+ Some(value) => {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, " = {}", coco)?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::EntryPoint(ep_index as _),
+ info: fun_info,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ pipeline_options,
+ },
+ mod_info,
+ result_struct: Some(&stage_out_name),
+ };
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ if ep_index + 1 != module.entry_points.len() {
+ writeln!(self.out)?;
+ }
+ }
+
+ Ok(info)
+ }
+}
+
+#[test]
+fn test_stack_size() {
+ use crate::valid::{Capabilities, ValidationFlags};
+ // create a module with at least one expression nested
+ let mut module = crate::Module::default();
+ let constant = module.constants.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(1.0),
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+ let mut fun = crate::Function::default();
+ let const_expr = fun
+ .expressions
+ .append(crate::Expression::Constant(constant), Default::default());
+ let nested_expr = fun.expressions.append(
+ crate::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr: const_expr,
+ },
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::Emit(fun.expressions.range_from(1)),
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::If {
+ condition: nested_expr,
+ accept: crate::Block::new(),
+ reject: crate::Block::new(),
+ },
+ Default::default(),
+ );
+ let _ = module.functions.append(fun, Default::default());
+ // analyse the module
+ let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty())
+ .validate(&module)
+ .unwrap();
+ // process the module
+ let mut writer = Writer::new(String::new());
+ writer
+ .write(&module, &info, &Default::default(), &Default::default())
+ .unwrap();
+
+ {
+ // check expression stack
+ let mut addresses = usize::MAX..0usize;
+ for pointer in writer.put_expression_stack_pointers {
+ addresses.start = addresses.start.min(pointer as usize);
+ addresses.end = addresses.end.max(pointer as usize);
+ }
+ let stack_size = addresses.end - addresses.start;
+ // check the size (in debug only)
+ // last observed macOS value: 20528 (CI)
+ if !(11000..=25000).contains(&stack_size) {
+ panic!("`put_expression` stack size {} has changed!", stack_size);
+ }
+ }
+
+ {
+ // check block stack
+ let mut addresses = usize::MAX..0usize;
+ for pointer in writer.put_block_stack_pointers {
+ addresses.start = addresses.start.min(pointer as usize);
+ addresses.end = addresses.end.max(pointer as usize);
+ }
+ let stack_size = addresses.end - addresses.start;
+ // check the size (in debug only)
+ // last observed macOS value: 19152 (CI)
+ if !(9500..=20000).contains(&stack_size) {
+ panic!("`put_block` stack size {} has changed!", stack_size);
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs
new file mode 100644
index 0000000000..10fd5d72aa
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/block.rs
@@ -0,0 +1,2121 @@
+/*!
+Implementations for `BlockContext` methods.
+*/
+
+use super::{
+ index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, Dimension,
+ Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, WriterFlags,
+};
+use crate::{arena::Handle, proc::TypeResolution};
+use spirv::Word;
+
+fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
+ match *type_inner {
+ crate::TypeInner::Scalar { .. } => Dimension::Scalar,
+ crate::TypeInner::Vector { .. } => Dimension::Vector,
+ crate::TypeInner::Matrix { .. } => Dimension::Matrix,
+ _ => unreachable!(),
+ }
+}
+
+/// The results of emitting code for a left-hand-side expression.
+///
+/// On success, `write_expression_pointer` returns one of these.
+enum ExpressionPointer {
+ /// The pointer to the expression's value is available, as the value of the
+ /// expression with the given id.
+ Ready { pointer_id: Word },
+
+ /// The access expression must be conditional on the value of `condition`, a boolean
+ /// expression that is true if all indices are in bounds. If `condition` is true, then
+ /// `access` is an `OpAccessChain` instruction that will compute a pointer to the
+ /// expression's value. If `condition` is false, then executing `access` would be
+ /// undefined behavior.
+ Conditional {
+ condition: Word,
+ access: Instruction,
+ },
+}
+
+/// The termination statement to be added to the end of the block
+pub enum BlockExit {
+ /// Generates an OpReturn (void return)
+ Return,
+ /// Generates an OpBranch to the specified block
+ Branch {
+ /// The branch target block
+ target: Word,
+ },
+ /// Translates a loop `break if` into an `OpBranchConditional` to the
+ /// merge block if true (the merge block is passed through [`LoopContext::break_id`]
+ /// or else to the loop header (passed through [`preamble_id`])
+ ///
+ /// [`preamble_id`]: Self::BreakIf::preamble_id
+ BreakIf {
+ /// The condition of the `break if`
+ condition: Handle<crate::Expression>,
+ /// The loop header block id
+ preamble_id: Word,
+ },
+}
+
+impl Writer {
+ // Flip Y coordinate to adjust for coordinate space difference
+ // between SPIR-V and our IR.
+ // The `position_id` argument is a pointer to a `vecN<f32>`,
+ // whose `y` component we will negate.
+ fn write_epilogue_position_y_flip(
+ &mut self,
+ position_id: Word,
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: Some(spirv::StorageClass::Output),
+ }));
+ let index_y_id = self.get_index_constant(1);
+ let access_id = self.id_gen.next();
+ body.push(Instruction::access_chain(
+ float_ptr_type_id,
+ access_id,
+ position_id,
+ &[index_y_id],
+ ));
+
+ let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ }));
+ let load_id = self.id_gen.next();
+ body.push(Instruction::load(float_type_id, load_id, access_id, None));
+
+ let neg_id = self.id_gen.next();
+ body.push(Instruction::unary(
+ spirv::Op::FNegate,
+ float_type_id,
+ neg_id,
+ load_id,
+ ));
+
+ body.push(Instruction::store(access_id, neg_id, None));
+ Ok(())
+ }
+
+ // Clamp fragment depth between 0 and 1.
+ fn write_epilogue_frag_depth_clamp(
+ &mut self,
+ frag_depth_id: Word,
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ }));
+ let value0_id = self.get_constant_scalar(crate::ScalarValue::Float(0.0), 4);
+ let value1_id = self.get_constant_scalar(crate::ScalarValue::Float(1.0), 4);
+
+ let original_id = self.id_gen.next();
+ body.push(Instruction::load(
+ float_type_id,
+ original_id,
+ frag_depth_id,
+ None,
+ ));
+
+ let clamp_id = self.id_gen.next();
+ body.push(Instruction::ext_inst(
+ self.gl450_ext_inst_id,
+ spirv::GLOp::FClamp,
+ float_type_id,
+ clamp_id,
+ &[original_id, value0_id, value1_id],
+ ));
+
+ body.push(Instruction::store(frag_depth_id, clamp_id, None));
+ Ok(())
+ }
+
+ fn write_entry_point_return(
+ &mut self,
+ value_id: Word,
+ ir_result: &crate::FunctionResult,
+ result_members: &[ResultMember],
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ for (index, res_member) in result_members.iter().enumerate() {
+ let member_value_id = match ir_result.binding {
+ Some(_) => value_id,
+ None => {
+ let member_value_id = self.id_gen.next();
+ body.push(Instruction::composite_extract(
+ res_member.type_id,
+ member_value_id,
+ value_id,
+ &[index as u32],
+ ));
+ member_value_id
+ }
+ };
+
+ body.push(Instruction::store(res_member.id, member_value_id, None));
+
+ match res_member.built_in {
+ Some(crate::BuiltIn::Position { .. })
+ if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
+ {
+ self.write_epilogue_position_y_flip(res_member.id, body)?;
+ }
+ Some(crate::BuiltIn::FragDepth)
+ if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
+ {
+ self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
+ }
+ _ => {}
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<'w> BlockContext<'w> {
+ /// Decide whether to put off emitting instructions for `expr_handle`.
+ ///
+ /// We would like to gather together chains of `Access` and `AccessIndex`
+ /// Naga expressions into a single `OpAccessChain` SPIR-V instruction. To do
+ /// this, we don't generate instructions for these exprs when we first
+ /// encounter them. Their ids in `self.writer.cached.ids` are left as zero. Then,
+ /// once we encounter a `Load` or `Store` expression that actually needs the
+ /// chain's value, we call `write_expression_pointer` to handle the whole
+ /// thing in one fell swoop.
+ fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool {
+ match self.ir_function.expressions[expr_handle] {
+ crate::Expression::GlobalVariable(handle) => {
+ let ty = self.ir_module.global_variables[handle].ty;
+ match self.ir_module.types[ty].inner {
+ crate::TypeInner::BindingArray { .. } => false,
+ _ => true,
+ }
+ }
+ crate::Expression::LocalVariable(_) => true,
+ crate::Expression::FunctionArgument(index) => {
+ let arg = &self.ir_function.arguments[index as usize];
+ self.ir_module.types[arg.ty].inner.pointer_space().is_some()
+ }
+
+ // The chain rule: if this `Access...`'s `base` operand was
+ // previously omitted, then omit this one, too.
+ _ => self.cached.ids[expr_handle.index()] == 0,
+ }
+ }
+
+ /// Cache an expression for a value.
+ pub(super) fn cache_expression_value(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
+
+ let id = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => {
+ // See `is_intermediate`; we'll handle this later in
+ // `write_expression_pointer`.
+ 0
+ }
+ crate::Expression::Access { base, index } => {
+ let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
+ match *base_ty_inner {
+ crate::TypeInner::Vector { .. } => {
+ self.write_vector_access(expr_handle, base, index, block)?
+ }
+ crate::TypeInner::BindingArray {
+ base: binding_type, ..
+ } => {
+ let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
+ base: binding_type,
+ class: spirv::StorageClass::UniformConstant,
+ });
+
+ let result_id = match self.write_expression_pointer(
+ expr_handle,
+ block,
+ Some(binding_array_false_pointer),
+ )? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Texture array out-of-bounds handling",
+ ));
+ }
+ };
+
+ let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
+
+ let load_id = self.gen_id();
+ block.body.push(Instruction::load(
+ binding_type_id,
+ load_id,
+ result_id,
+ None,
+ ));
+
+ if self.fun_info[index].uniformity.non_uniform_result.is_some() {
+ self.writer.require_any(
+ "NonUniformEXT",
+ &[spirv::Capability::ShaderNonUniform],
+ )?;
+ self.writer.use_extension("SPV_EXT_descriptor_indexing");
+ self.writer
+ .decorate(load_id, spirv::Decoration::NonUniform, &[]);
+ }
+ load_id
+ }
+ ref other => {
+ log::error!(
+ "Unable to access base {:?} of type {:?}",
+ self.ir_function.expressions[base],
+ other
+ );
+ return Err(Error::Validation(
+ "only vectors may be dynamically indexed by value",
+ ));
+ }
+ }
+ }
+ crate::Expression::AccessIndex { base, index: _ } if self.is_intermediate(base) => {
+ // See `is_intermediate`; we'll handle this later in
+ // `write_expression_pointer`.
+ 0
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::Array { .. }
+ | crate::TypeInner::Struct { .. } => {
+ // We never need bounds checks here: dynamically sized arrays can
+ // only appear behind pointers, and are thus handled by the
+ // `is_intermediate` case above. Everything else's size is
+ // statically known and checked in validation.
+ let id = self.gen_id();
+ let base_id = self.cached[base];
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ id,
+ base_id,
+ &[index],
+ ));
+ id
+ }
+ crate::TypeInner::BindingArray {
+ base: binding_type, ..
+ } => {
+ let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
+ base: binding_type,
+ class: spirv::StorageClass::UniformConstant,
+ });
+
+ let result_id = match self.write_expression_pointer(
+ expr_handle,
+ block,
+ Some(binding_array_false_pointer),
+ )? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Texture array out-of-bounds handling",
+ ));
+ }
+ };
+
+ let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
+
+ let load_id = self.gen_id();
+ block.body.push(Instruction::load(
+ binding_type_id,
+ load_id,
+ result_id,
+ None,
+ ));
+
+ load_id
+ }
+ ref other => {
+ log::error!("Unable to access index of {:?}", other);
+ return Err(Error::FeatureNotImplemented("access index for type"));
+ }
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ self.writer.global_variables[handle.index()].access_id
+ }
+ crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()],
+ crate::Expression::Splat { size, value } => {
+ let value_id = self.cached[value];
+ let components = [value_id; 4];
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ &components[..size as usize],
+ ));
+ id
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ let vector_id = self.cached[vector];
+ self.temp_list.clear();
+ for &sc in pattern[..size as usize].iter() {
+ self.temp_list.push(sc as Word);
+ }
+ let id = self.gen_id();
+ block.body.push(Instruction::vector_shuffle(
+ result_type_id,
+ id,
+ vector_id,
+ vector_id,
+ &self.temp_list,
+ ));
+ id
+ }
+ crate::Expression::Compose {
+ ty: _,
+ ref components,
+ } => {
+ self.temp_list.clear();
+ for &component in components {
+ self.temp_list.push(self.cached[component]);
+ }
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ &self.temp_list,
+ ));
+ id
+ }
+ crate::Expression::Unary { op, expr } => {
+ let id = self.gen_id();
+ let expr_id = self.cached[expr];
+ let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
+
+ let spirv_op = match op {
+ crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
+ Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNot,
+ Some(crate::ScalarKind::Uint) | None => {
+ log::error!("Unable to negate {:?}", expr_ty_inner);
+ return Err(Error::FeatureNotImplemented("negation"));
+ }
+ },
+ crate::UnaryOperator::Not => match expr_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNot,
+ _ => spirv::Op::Not,
+ },
+ };
+
+ block
+ .body
+ .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
+ id
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let id = self.gen_id();
+ let left_id = self.cached[left];
+ let right_id = self.cached[right];
+
+ let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
+ let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
+
+ let left_dimension = get_dimension(left_ty_inner);
+ let right_dimension = get_dimension(right_ty_inner);
+
+ let mut reverse_operands = false;
+
+ let spirv_op = match op {
+ crate::BinaryOperator::Add => match *left_ty_inner {
+ crate::TypeInner::Scalar { kind, .. }
+ | crate::TypeInner::Vector { kind, .. } => match kind {
+ crate::ScalarKind::Float => spirv::Op::FAdd,
+ _ => spirv::Op::IAdd,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ self.write_matrix_matrix_column_op(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ columns,
+ rows,
+ width,
+ spirv::Op::FAdd,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Subtract => match *left_ty_inner {
+ crate::TypeInner::Scalar { kind, .. }
+ | crate::TypeInner::Vector { kind, .. } => match kind {
+ crate::ScalarKind::Float => spirv::Op::FSub,
+ _ => spirv::Op::ISub,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ self.write_matrix_matrix_column_op(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ columns,
+ rows,
+ width,
+ spirv::Op::FSub,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) {
+ (Dimension::Scalar, Dimension::Vector) => {
+ self.write_vector_scalar_mult(
+ block,
+ id,
+ result_type_id,
+ right_id,
+ left_id,
+ right_ty_inner,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ (Dimension::Vector, Dimension::Scalar) => {
+ self.write_vector_scalar_mult(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ left_ty_inner,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ (Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix,
+ (Dimension::Matrix, Dimension::Scalar) => spirv::Op::MatrixTimesScalar,
+ (Dimension::Scalar, Dimension::Matrix) => {
+ reverse_operands = true;
+ spirv::Op::MatrixTimesScalar
+ }
+ (Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector,
+ (Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix,
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar)
+ if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) =>
+ {
+ spirv::Op::FMul
+ }
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
+ },
+ crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
+ Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
+ // TODO: handle undefined behavior
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ Some(crate::ScalarKind::Sint) => spirv::Op::SRem,
+ // TODO: handle undefined behavior
+ // if right == 0 return 0
+ Some(crate::ScalarKind::Uint) => spirv::Op::UMod,
+ // TODO: handle undefined behavior
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+ Some(crate::ScalarKind::Float) => spirv::Op::FRem,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ spirv::Op::IEqual
+ }
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ spirv::Op::INotEqual
+ }
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
+ _ => spirv::Op::BitwiseAnd,
+ },
+ crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
+ crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
+ _ => spirv::Op::BitwiseOr,
+ },
+ crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
+ crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
+ crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
+ crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
+ _ => unimplemented!(),
+ },
+ };
+
+ block.body.push(Instruction::binary(
+ spirv_op,
+ result_type_id,
+ id,
+ if reverse_operands { right_id } else { left_id },
+ if reverse_operands { left_id } else { right_id },
+ ));
+ id
+ }
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+ enum MathOp {
+ Ext(spirv::GLOp),
+ Custom(Instruction),
+ }
+
+ let arg0_id = self.cached[arg];
+ let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
+ let arg_scalar_kind = arg_ty.scalar_kind();
+ let arg1_id = match arg1 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+ let arg2_id = match arg2 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+ let arg3_id = match arg3 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+
+ let id = self.gen_id();
+ let math_op = match fun {
+ // comparison
+ Mf::Abs => {
+ match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
+ Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
+ Some(crate::ScalarKind::Uint) => {
+ MathOp::Custom(Instruction::unary(
+ spirv::Op::CopyObject, // do nothing
+ result_type_id,
+ id,
+ arg0_id,
+ ))
+ }
+ other => unimplemented!("Unexpected abs({:?})", other),
+ }
+ }
+ Mf::Min => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
+ other => unimplemented!("Unexpected min({:?})", other),
+ }),
+ Mf::Max => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
+ other => unimplemented!("Unexpected max({:?})", other),
+ }),
+ Mf::Clamp => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FClamp,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SClamp,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UClamp,
+ other => unimplemented!("Unexpected max({:?})", other),
+ }),
+ Mf::Saturate => {
+ let (maybe_size, width) = match *arg_ty {
+ crate::TypeInner::Vector { size, width, .. } => (Some(size), width),
+ crate::TypeInner::Scalar { width, .. } => (None, width),
+ ref other => unimplemented!("Unexpected saturate({:?})", other),
+ };
+
+ let mut arg1_id = self
+ .writer
+ .get_constant_scalar(crate::ScalarValue::Float(0.0), width);
+ let mut arg2_id = self
+ .writer
+ .get_constant_scalar(crate::ScalarValue::Float(1.0), width);
+
+ if let Some(size) = maybe_size {
+ let value = LocalType::Value {
+ vector_size: Some(size),
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ };
+
+ let result_type_id = self.get_type_id(LookupType::Local(value));
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, arg1_id);
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ &self.temp_list,
+ ));
+ arg1_id = id;
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, arg2_id);
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ &self.temp_list,
+ ));
+ arg2_id = id;
+ }
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FClamp,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, arg2_id],
+ ))
+ }
+ // trigonometry
+ Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
+ Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
+ Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
+ Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
+ Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
+ Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
+ Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
+ Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
+ Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
+ Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
+ Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
+ Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
+ Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
+ Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
+ Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
+ // decomposition
+ Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
+ Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
+ Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
+ Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
+ Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
+ Mf::Modf => MathOp::Ext(spirv::GLOp::Modf),
+ Mf::Frexp => MathOp::Ext(spirv::GLOp::Frexp),
+ Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
+ // geometry
+ Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Float,
+ ..
+ } => MathOp::Custom(Instruction::binary(
+ spirv::Op::Dot,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ )),
+ // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available
+ crate::TypeInner::Vector { size, .. } => {
+ self.write_dot_product(
+ id,
+ result_type_id,
+ arg0_id,
+ arg1_id,
+ size as u32,
+ block,
+ );
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => MathOp::Custom(Instruction::binary(
+ spirv::Op::OuterProduct,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ )),
+ Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
+ Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
+ Mf::Length => MathOp::Ext(spirv::GLOp::Length),
+ Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
+ Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
+ Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
+ Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
+ // exponent
+ Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
+ Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
+ Mf::Log => MathOp::Ext(spirv::GLOp::Log),
+ Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
+ Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
+ // computational
+ Mf::Sign => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
+ other => unimplemented!("Unexpected sign({:?})", other),
+ }),
+ Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
+ Mf::Mix => {
+ let selector = arg2.unwrap();
+ let selector_ty =
+ self.fun_info[selector].ty.inner_with(&self.ir_module.types);
+ match (arg_ty, selector_ty) {
+ // if the selector is a scalar, we need to splat it
+ (
+ &crate::TypeInner::Vector { size, .. },
+ &crate::TypeInner::Scalar { kind, width },
+ ) => {
+ let selector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ kind,
+ width,
+ pointer_space: None,
+ }));
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, arg2_id);
+
+ let selector_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ selector_type_id,
+ selector_id,
+ &self.temp_list,
+ ));
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FMix,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, selector_id],
+ ))
+ }
+ _ => MathOp::Ext(spirv::GLOp::FMix),
+ }
+ }
+ Mf::Step => MathOp::Ext(spirv::GLOp::Step),
+ Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
+ Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
+ Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
+ Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
+ Mf::Transpose => MathOp::Custom(Instruction::unary(
+ spirv::Op::Transpose,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
+ Mf::ReverseBits => MathOp::Custom(Instruction::unary(
+ spirv::Op::BitReverse,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::CountOneBits => MathOp::Custom(Instruction::unary(
+ spirv::Op::BitCount,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::ExtractBits => {
+ let op = match arg_scalar_kind {
+ Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
+ Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
+ other => unimplemented!("Unexpected sign({:?})", other),
+ };
+ MathOp::Custom(Instruction::ternary(
+ op,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ arg2_id,
+ ))
+ }
+ Mf::InsertBits => MathOp::Custom(Instruction::quaternary(
+ spirv::Op::BitFieldInsert,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ arg2_id,
+ arg3_id,
+ )),
+ Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
+ Mf::FindMsb => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
+ other => unimplemented!("Unexpected findMSB({:?})", other),
+ }),
+ Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
+ Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
+ Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
+ Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
+ Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
+ Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
+ Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
+ Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
+ Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
+ Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
+ };
+
+ block.body.push(match math_op {
+ MathOp::Ext(op) => Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ op,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
+ ),
+ MathOp::Custom(inst) => inst,
+ });
+ id
+ }
+ crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id,
+ crate::Expression::Load { pointer } => {
+ match self.write_expression_pointer(pointer, block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let id = self.gen_id();
+ let atomic_space =
+ match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Pointer { base, space } => {
+ match self.ir_module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Some(space),
+ _ => None,
+ }
+ }
+ _ => None,
+ };
+ let instruction = if let Some(space) = atomic_space {
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ Instruction::atomic_load(
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ )
+ } else {
+ Instruction::load(result_type_id, id, pointer_id, None)
+ };
+ block.body.push(instruction);
+ id
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ //TODO: support atomics?
+ self.write_conditional_indexed_load(
+ result_type_id,
+ condition,
+ block,
+ move |id_gen, block| {
+ // The in-bounds path. Perform the access and the load.
+ let pointer_id = access.result_id.unwrap();
+ let value_id = id_gen.next();
+ block.body.push(access);
+ block.body.push(Instruction::load(
+ result_type_id,
+ value_id,
+ pointer_id,
+ None,
+ ));
+ value_id
+ },
+ )
+ }
+ }
+ }
+ crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
+ crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => {
+ self.cached[expr_handle]
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ use crate::ScalarKind as Sk;
+
+ let expr_id = self.cached[expr];
+ let (src_kind, src_size, src_width, is_matrix) =
+ match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Scalar { kind, width } => (kind, None, width, false),
+ crate::TypeInner::Vector { kind, width, size } => {
+ (kind, Some(size), width, false)
+ }
+ crate::TypeInner::Matrix { width, .. } => (kind, None, width, true),
+ ref other => {
+ log::error!("As source {:?}", other);
+ return Err(Error::Validation("Unexpected Expression::As source"));
+ }
+ };
+
+ enum Cast {
+ Identity,
+ Unary(spirv::Op),
+ Binary(spirv::Op, Word),
+ Ternary(spirv::Op, Word, Word),
+ }
+
+ let cast = if is_matrix {
+ // we only support identity casts for matrices
+ Cast::Unary(spirv::Op::CopyObject)
+ } else {
+ match (src_kind, kind, convert) {
+ // Filter out identity casts. Some Adreno drivers are
+ // confused by no-op OpBitCast instructions.
+ (src_kind, kind, convert)
+ if src_kind == kind && convert.unwrap_or(src_width) == src_width =>
+ {
+ Cast::Identity
+ }
+ (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject),
+ (_, _, None) => Cast::Unary(spirv::Op::Bitcast),
+ // casting to a bool - generate `OpXxxNotEqual`
+ (_, Sk::Bool, Some(_)) => {
+ let (op, value) = match src_kind {
+ Sk::Sint => (spirv::Op::INotEqual, crate::ScalarValue::Sint(0)),
+ Sk::Uint => (spirv::Op::INotEqual, crate::ScalarValue::Uint(0)),
+ Sk::Float => {
+ (spirv::Op::FUnordNotEqual, crate::ScalarValue::Float(0.0))
+ }
+ Sk::Bool => unreachable!(),
+ };
+ let zero_scalar_id = self.writer.get_constant_scalar(value, src_width);
+ let zero_id = match src_size {
+ Some(size) => {
+ let vector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ kind: src_kind,
+ width: src_width,
+ pointer_space: None,
+ }));
+ let components = [zero_scalar_id; 4];
+
+ let zero_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ vector_type_id,
+ zero_id,
+ &components[..size as usize],
+ ));
+ zero_id
+ }
+ None => zero_scalar_id,
+ };
+
+ Cast::Binary(op, zero_id)
+ }
+ // casting from a bool - generate `OpSelect`
+ (Sk::Bool, _, Some(dst_width)) => {
+ let (val0, val1) = match kind {
+ Sk::Sint => {
+ (crate::ScalarValue::Sint(0), crate::ScalarValue::Sint(1))
+ }
+ Sk::Uint => {
+ (crate::ScalarValue::Uint(0), crate::ScalarValue::Uint(1))
+ }
+ Sk::Float => (
+ crate::ScalarValue::Float(0.0),
+ crate::ScalarValue::Float(1.0),
+ ),
+ Sk::Bool => unreachable!(),
+ };
+ let scalar0_id = self.writer.get_constant_scalar(val0, dst_width);
+ let scalar1_id = self.writer.get_constant_scalar(val1, dst_width);
+ let (accept_id, reject_id) = match src_size {
+ Some(size) => {
+ let vector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ kind,
+ width: dst_width,
+ pointer_space: None,
+ }));
+ let components0 = [scalar0_id; 4];
+ let components1 = [scalar1_id; 4];
+
+ let vec0_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ vector_type_id,
+ vec0_id,
+ &components0[..size as usize],
+ ));
+ let vec1_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ vector_type_id,
+ vec1_id,
+ &components1[..size as usize],
+ ));
+ (vec1_id, vec0_id)
+ }
+ None => (scalar1_id, scalar0_id),
+ };
+
+ Cast::Ternary(spirv::Op::Select, accept_id, reject_id)
+ }
+ (Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU),
+ (Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS),
+ (Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => {
+ Cast::Unary(spirv::Op::FConvert)
+ }
+ (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF),
+ (Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => {
+ Cast::Unary(spirv::Op::SConvert)
+ }
+ (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF),
+ (Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => {
+ Cast::Unary(spirv::Op::UConvert)
+ }
+ // We assume it's either an identity cast, or int-uint.
+ _ => Cast::Unary(spirv::Op::Bitcast),
+ }
+ };
+
+ let id = self.gen_id();
+ let instruction = match cast {
+ Cast::Identity => None,
+ Cast::Unary(op) => Some(Instruction::unary(op, result_type_id, id, expr_id)),
+ Cast::Binary(op, operand) => Some(Instruction::binary(
+ op,
+ result_type_id,
+ id,
+ expr_id,
+ operand,
+ )),
+ Cast::Ternary(op, op1, op2) => Some(Instruction::ternary(
+ op,
+ result_type_id,
+ id,
+ expr_id,
+ op1,
+ op2,
+ )),
+ };
+ if let Some(instruction) = instruction {
+ block.body.push(instruction);
+ id
+ } else {
+ expr_id
+ }
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => self.write_image_load(
+ result_type_id,
+ image,
+ coordinate,
+ array_index,
+ level,
+ sample,
+ block,
+ )?,
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => self.write_image_sample(
+ result_type_id,
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ block,
+ )?,
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let id = self.gen_id();
+ let mut condition_id = self.cached[condition];
+ let accept_id = self.cached[accept];
+ let reject_id = self.cached[reject];
+
+ let condition_ty = self.fun_info[condition]
+ .ty
+ .inner_with(&self.ir_module.types);
+ let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
+
+ if let (
+ &crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width,
+ },
+ &crate::TypeInner::Vector { size, .. },
+ ) = (condition_ty, object_ty)
+ {
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, condition_id);
+
+ let bool_vector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ kind: crate::ScalarKind::Bool,
+ width,
+ pointer_space: None,
+ }));
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ bool_vector_type_id,
+ id,
+ &self.temp_list,
+ ));
+ condition_id = id
+ }
+
+ let instruction =
+ Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
+ block.body.push(instruction);
+ id
+ }
+ crate::Expression::Derivative { axis, expr } => {
+ use crate::DerivativeAxis as Da;
+
+ let id = self.gen_id();
+ let expr_id = self.cached[expr];
+ let op = match axis {
+ Da::X => spirv::Op::DPdx,
+ Da::Y => spirv::Op::DPdy,
+ Da::Width => spirv::Op::Fwidth,
+ };
+ block
+ .body
+ .push(Instruction::derivative(op, result_type_id, id, expr_id));
+ id
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ self.write_image_query(result_type_id, image, query, block)?
+ }
+ crate::Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+ let arg_id = self.cached[argument];
+ let op = match fun {
+ Rf::All => spirv::Op::All,
+ Rf::Any => spirv::Op::Any,
+ Rf::IsNan => spirv::Op::IsNan,
+ Rf::IsInf => spirv::Op::IsInf,
+ //TODO: these require Kernel capability
+ Rf::IsFinite | Rf::IsNormal => {
+ return Err(Error::FeatureNotImplemented("is finite/normal"))
+ }
+ };
+ let id = self.gen_id();
+ block
+ .body
+ .push(Instruction::relational(op, result_type_id, id, arg_id));
+ id
+ }
+ crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
+ };
+
+ self.cached[expr_handle] = id;
+ Ok(())
+ }
+
+ /// Build an `OpAccessChain` instruction.
+ ///
+ /// Emit any needed bounds-checking expressions to `block`.
+ ///
+ /// Some cases we need to generate a different return type than what the IR gives us.
+ /// This is because pointers to binding arrays don't exist in the IR, but we need to
+ /// create them to create an access chain in SPIRV.
+ ///
+ /// On success, the return value is an [`ExpressionPointer`] value; see the
+ /// documentation for that type.
+ fn write_expression_pointer(
+ &mut self,
+ mut expr_handle: Handle<crate::Expression>,
+ block: &mut Block,
+ return_type_override: Option<LookupType>,
+ ) -> Result<ExpressionPointer, Error> {
+ let result_lookup_ty = match self.fun_info[expr_handle].ty {
+ TypeResolution::Handle(ty_handle) => match return_type_override {
+ // We use the return type override as a special case for binding arrays as the OpAccessChain
+ // needs to return a pointer, but indexing into a binding array just gives you the type of
+ // the binding in the IR.
+ Some(ty) => ty,
+ None => LookupType::Handle(ty_handle),
+ },
+ TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
+ };
+ let result_type_id = self.get_type_id(result_lookup_ty);
+
+ // The id of the boolean `and` of all dynamic bounds checks up to this point. If
+ // `None`, then we haven't done any dynamic bounds checks yet.
+ //
+ // When we have a chain of bounds checks, we combine them with `OpLogicalAnd`, not
+ // a short-circuit branch. This means we might do comparisons we don't need to,
+ // but we expect these checks to almost always succeed, and keeping branches to a
+ // minimum is essential.
+ let mut accumulated_checks = None;
+
+ self.temp_list.clear();
+ let root_id = loop {
+ expr_handle = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::Access { base, index } => {
+ let index_id = match self.write_bounds_check(base, index, block)? {
+ BoundsCheckResult::KnownInBounds(known_index) => {
+ // Even if the index is known, `OpAccessIndex`
+ // requires expression operands, not literals.
+ let scalar = crate::ScalarValue::Uint(known_index as u64);
+ self.writer.get_constant_scalar(scalar, 4)
+ }
+ BoundsCheckResult::Computed(computed_index_id) => computed_index_id,
+ BoundsCheckResult::Conditional(comparison_id) => {
+ match accumulated_checks {
+ Some(prior_checks) => {
+ let combined = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::LogicalAnd,
+ self.writer.get_bool_type_id(),
+ combined,
+ prior_checks,
+ comparison_id,
+ ));
+ accumulated_checks = Some(combined);
+ }
+ None => {
+ // Start a fresh chain of checks.
+ accumulated_checks = Some(comparison_id);
+ }
+ }
+
+ // Either way, the index to use is unchanged.
+ self.cached[index]
+ }
+ };
+ self.temp_list.push(index_id);
+
+ base
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let const_id = self.get_index_constant(index);
+ self.temp_list.push(const_id);
+ base
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let gv = &self.writer.global_variables[handle.index()];
+ break gv.access_id;
+ }
+ crate::Expression::LocalVariable(variable) => {
+ let local_var = &self.function.variables[&variable];
+ break local_var.id;
+ }
+ crate::Expression::FunctionArgument(index) => {
+ break self.function.parameter_id(index);
+ }
+ ref other => unimplemented!("Unexpected pointer expression {:?}", other),
+ }
+ };
+
+ let pointer = if self.temp_list.is_empty() {
+ ExpressionPointer::Ready {
+ pointer_id: root_id,
+ }
+ } else {
+ self.temp_list.reverse();
+ let pointer_id = self.gen_id();
+ let access =
+ Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);
+
+ // If we generated some bounds checks, we need to leave it to our
+ // caller to generate the branch, the access, the load or store, and
+ // the zero value (for loads). Otherwise, we can emit the access
+ // ourselves, and just hand them the id of the pointer.
+ match accumulated_checks {
+ Some(condition) => ExpressionPointer::Conditional { condition, access },
+ None => {
+ block.body.push(access);
+ ExpressionPointer::Ready { pointer_id }
+ }
+ }
+ };
+
+ Ok(pointer)
+ }
+
+ /// Build the instructions for matrix - matrix column operations
+ #[allow(clippy::too_many_arguments)]
+ fn write_matrix_matrix_column_op(
+ &mut self,
+ block: &mut Block,
+ result_id: Word,
+ result_type_id: Word,
+ left_id: Word,
+ right_id: Word,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: u8,
+ op: spirv::Op,
+ ) {
+ self.temp_list.clear();
+
+ let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(rows),
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ }));
+
+ for index in 0..columns as u32 {
+ let column_id_left = self.gen_id();
+ let column_id_right = self.gen_id();
+ let column_id_res = self.gen_id();
+
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ column_id_left,
+ left_id,
+ &[index],
+ ));
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ column_id_right,
+ right_id,
+ &[index],
+ ));
+ block.body.push(Instruction::binary(
+ op,
+ vector_type_id,
+ column_id_res,
+ column_id_left,
+ column_id_right,
+ ));
+
+ self.temp_list.push(column_id_res);
+ }
+
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ result_id,
+ &self.temp_list,
+ ));
+ }
+
+ /// Build the instructions for vector - scalar multiplication
+ fn write_vector_scalar_mult(
+ &mut self,
+ block: &mut Block,
+ result_id: Word,
+ result_type_id: Word,
+ vector_id: Word,
+ scalar_id: Word,
+ vector: &crate::TypeInner,
+ ) {
+ let (size, kind) = match *vector {
+ crate::TypeInner::Vector { size, kind, .. } => (size, kind),
+ _ => unreachable!(),
+ };
+
+ let (op, operand_id) = match kind {
+ crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
+ _ => {
+ let operand_id = self.gen_id();
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, scalar_id);
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ operand_id,
+ &self.temp_list,
+ ));
+ (spirv::Op::IMul, operand_id)
+ }
+ };
+
+ block.body.push(Instruction::binary(
+ op,
+ result_type_id,
+ result_id,
+ vector_id,
+ operand_id,
+ ));
+ }
+
+ /// Build the instructions for the arithmetic expression of a dot product
+ fn write_dot_product(
+ &mut self,
+ result_id: Word,
+ result_type_id: Word,
+ arg0_id: Word,
+ arg1_id: Word,
+ size: u32,
+ block: &mut Block,
+ ) {
+ let const_null = self.gen_id();
+ block
+ .body
+ .push(Instruction::constant_null(result_type_id, const_null));
+
+ let mut partial_sum = const_null;
+ let last_component = size - 1;
+ for index in 0..=last_component {
+ // compute the product of the current components
+ let a_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ a_id,
+ arg0_id,
+ &[index],
+ ));
+ let b_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ b_id,
+ arg1_id,
+ &[index],
+ ));
+ let prod_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::IMul,
+ result_type_id,
+ prod_id,
+ a_id,
+ b_id,
+ ));
+
+ // choose the id for the next sum, depending on current index
+ let id = if index == last_component {
+ result_id
+ } else {
+ self.gen_id()
+ };
+
+ // sum the computed product with the partial sum
+ block.body.push(Instruction::binary(
+ spirv::Op::IAdd,
+ result_type_id,
+ id,
+ partial_sum,
+ prod_id,
+ ));
+ // set the id of the result as the previous partial sum
+ partial_sum = id;
+ }
+ }
+
+ pub(super) fn write_block(
+ &mut self,
+ label_id: Word,
+ statements: &[crate::Statement],
+ exit: BlockExit,
+ loop_context: LoopContext,
+ ) -> Result<(), Error> {
+ let mut block = Block::new(label_id);
+
+ for statement in statements {
+ match *statement {
+ crate::Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ self.cache_expression_value(handle, &mut block)?;
+ }
+ }
+ crate::Statement::Block(ref block_statements) => {
+ let scope_id = self.gen_id();
+ self.function.consume(block, Instruction::branch(scope_id));
+
+ let merge_id = self.gen_id();
+ self.write_block(
+ scope_id,
+ block_statements,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ )?;
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let condition_id = self.cached[condition];
+
+ let merge_id = self.gen_id();
+ block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let accept_id = if accept.is_empty() {
+ None
+ } else {
+ Some(self.gen_id())
+ };
+ let reject_id = if reject.is_empty() {
+ None
+ } else {
+ Some(self.gen_id())
+ };
+
+ self.function.consume(
+ block,
+ Instruction::branch_conditional(
+ condition_id,
+ accept_id.unwrap_or(merge_id),
+ reject_id.unwrap_or(merge_id),
+ ),
+ );
+
+ if let Some(block_id) = accept_id {
+ self.write_block(
+ block_id,
+ accept,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ )?;
+ }
+ if let Some(block_id) = reject_id {
+ self.write_block(
+ block_id,
+ reject,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ )?;
+ }
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ let selector_id = self.cached[selector];
+
+ let merge_id = self.gen_id();
+ block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let default_id = self.gen_id();
+
+ let mut reached_default = false;
+ let mut raw_cases = Vec::with_capacity(cases.len());
+ let mut case_ids = Vec::with_capacity(cases.len());
+ for case in cases.iter() {
+ match case.value {
+ crate::SwitchValue::Integer(value) => {
+ let label_id = self.gen_id();
+ // No cases should be added after the default case is encountered
+ // since the default case catches all
+ if !reached_default {
+ raw_cases.push(super::instructions::Case {
+ value: value as Word,
+ label_id,
+ });
+ }
+ case_ids.push(label_id);
+ }
+ crate::SwitchValue::Default => {
+ case_ids.push(default_id);
+ reached_default = true;
+ }
+ }
+ }
+
+ self.function.consume(
+ block,
+ Instruction::switch(selector_id, default_id, &raw_cases),
+ );
+
+ let inner_context = LoopContext {
+ break_id: Some(merge_id),
+ ..loop_context
+ };
+
+ for (i, (case, label_id)) in cases.iter().zip(case_ids.iter()).enumerate() {
+ let case_finish_id = if case.fall_through {
+ case_ids[i + 1]
+ } else {
+ merge_id
+ };
+ self.write_block(
+ *label_id,
+ &case.body,
+ BlockExit::Branch {
+ target: case_finish_id,
+ },
+ inner_context,
+ )?;
+ }
+
+ // If no default was encountered write a empty block to satisfy the presence of
+ // a block the default label
+ if !reached_default {
+ self.write_block(
+ default_id,
+ &[],
+ BlockExit::Branch { target: merge_id },
+ inner_context,
+ )?;
+ }
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let preamble_id = self.gen_id();
+ self.function
+ .consume(block, Instruction::branch(preamble_id));
+
+ let merge_id = self.gen_id();
+ let body_id = self.gen_id();
+ let continuing_id = self.gen_id();
+
+ // SPIR-V requires the continuing to the `OpLoopMerge`,
+ // so we have to start a new block with it.
+ block = Block::new(preamble_id);
+ block.body.push(Instruction::loop_merge(
+ merge_id,
+ continuing_id,
+ spirv::SelectionControl::NONE,
+ ));
+ self.function.consume(block, Instruction::branch(body_id));
+
+ self.write_block(
+ body_id,
+ body,
+ BlockExit::Branch {
+ target: continuing_id,
+ },
+ LoopContext {
+ continuing_id: Some(continuing_id),
+ break_id: Some(merge_id),
+ },
+ )?;
+
+ let exit = match break_if {
+ Some(condition) => BlockExit::BreakIf {
+ condition,
+ preamble_id,
+ },
+ None => BlockExit::Branch {
+ target: preamble_id,
+ },
+ };
+
+ self.write_block(
+ continuing_id,
+ continuing,
+ exit,
+ LoopContext {
+ continuing_id: None,
+ break_id: Some(merge_id),
+ },
+ )?;
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Break => {
+ self.function
+ .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
+ return Ok(());
+ }
+ crate::Statement::Continue => {
+ self.function.consume(
+ block,
+ Instruction::branch(loop_context.continuing_id.unwrap()),
+ );
+ return Ok(());
+ }
+ crate::Statement::Return { value: Some(value) } => {
+ let value_id = self.cached[value];
+ let instruction = match self.function.entry_point_context {
+ // If this is an entry point, and we need to return anything,
+ // let's instead store the output variables and return `void`.
+ Some(ref context) => {
+ self.writer.write_entry_point_return(
+ value_id,
+ self.ir_function.result.as_ref().unwrap(),
+ &context.results,
+ &mut block.body,
+ )?;
+ Instruction::return_void()
+ }
+ None => Instruction::return_value(value_id),
+ };
+ self.function.consume(block, instruction);
+ return Ok(());
+ }
+ crate::Statement::Return { value: None } => {
+ self.function.consume(block, Instruction::return_void());
+ return Ok(());
+ }
+ crate::Statement::Kill => {
+ self.function.consume(block, Instruction::kill());
+ return Ok(());
+ }
+ crate::Statement::Barrier(flags) => {
+ let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
+ spirv::Scope::Device
+ } else {
+ spirv::Scope::Workgroup
+ };
+ let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
+ semantics.set(
+ spirv::MemorySemantics::UNIFORM_MEMORY,
+ flags.contains(crate::Barrier::STORAGE),
+ );
+ semantics.set(
+ spirv::MemorySemantics::WORKGROUP_MEMORY,
+ flags.contains(crate::Barrier::WORK_GROUP),
+ );
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
+ let mem_scope_id = self.get_index_constant(memory_scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ block.body.push(Instruction::control_barrier(
+ exec_scope_id,
+ mem_scope_id,
+ semantics_id,
+ ));
+ }
+ crate::Statement::Store { pointer, value } => {
+ let value_id = self.cached[value];
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let atomic_space = match *self.fun_info[pointer]
+ .ty
+ .inner_with(&self.ir_module.types)
+ {
+ crate::TypeInner::Pointer { base, space } => {
+ match self.ir_module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Some(space),
+ _ => None,
+ }
+ }
+ _ => None,
+ };
+ let instruction = if let Some(space) = atomic_space {
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ Instruction::atomic_store(
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ } else {
+ Instruction::store(pointer_id, value_id, None)
+ };
+ block.body.push(instruction);
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ let mut selection = Selection::start(&mut block, ());
+ selection.if_true(self, condition, ());
+
+ // The in-bounds path. Perform the access and the store.
+ let pointer_id = access.result_id.unwrap();
+ selection.block().body.push(access);
+ selection
+ .block()
+ .body
+ .push(Instruction::store(pointer_id, value_id, None));
+
+ // Finish the in-bounds block and start the merge block. This
+ // is the block we'll leave current on return.
+ selection.finish(self, ());
+ }
+ };
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
+ crate::Statement::Call {
+ function: local_function,
+ ref arguments,
+ result,
+ } => {
+ let id = self.gen_id();
+ self.temp_list.clear();
+ for &argument in arguments {
+ self.temp_list.push(self.cached[argument]);
+ }
+
+ let type_id = match result {
+ Some(expr) => {
+ self.cached[expr] = id;
+ self.get_expression_type_id(&self.fun_info[expr].ty)
+ }
+ None => self.writer.void_type,
+ };
+
+ block.body.push(Instruction::function_call(
+ type_id,
+ id,
+ self.writer.lookup_function[&local_function],
+ &self.temp_list,
+ ));
+ }
+ crate::Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ let id = self.gen_id();
+ let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
+
+ self.cached[result] = id;
+
+ let pointer_id =
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Atomics out-of-bounds handling",
+ ));
+ }
+ };
+
+ let space = self.fun_info[pointer]
+ .ty
+ .inner_with(&self.ir_module.types)
+ .pointer_space()
+ .unwrap();
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ let value_id = self.cached[value];
+ let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
+
+ let instruction = match *fun {
+ crate::AtomicFunction::Add => Instruction::atomic_binary(
+ spirv::Op::AtomicIAdd,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::Subtract => Instruction::atomic_binary(
+ spirv::Op::AtomicISub,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::And => Instruction::atomic_binary(
+ spirv::Op::AtomicAnd,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary(
+ spirv::Op::AtomicOr,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary(
+ spirv::Op::AtomicXor,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::Min => {
+ let spirv_op = match *value_inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ } => spirv::Op::AtomicSMin,
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ } => spirv::Op::AtomicUMin,
+ _ => unimplemented!(),
+ };
+ Instruction::atomic_binary(
+ spirv_op,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Max => {
+ let spirv_op = match *value_inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ } => spirv::Op::AtomicSMax,
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ } => spirv::Op::AtomicUMax,
+ _ => unimplemented!(),
+ };
+ Instruction::atomic_binary(
+ spirv_op,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Exchange { compare: None } => {
+ Instruction::atomic_binary(
+ spirv::Op::AtomicExchange,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::FeatureNotImplemented("atomic CompareExchange"));
+ }
+ };
+
+ block.body.push(instruction);
+ }
+ }
+ }
+
+ let termination = match exit {
+ // We're generating code for the top-level Block of the function, so we
+ // need to end it with some kind of return instruction.
+ BlockExit::Return => match self.ir_function.result {
+ Some(ref result) if self.function.entry_point_context.is_none() => {
+ let type_id = self.get_type_id(LookupType::Handle(result.ty));
+ let null_id = self.writer.write_constant_null(type_id);
+ Instruction::return_value(null_id)
+ }
+ _ => Instruction::return_void(),
+ },
+ BlockExit::Branch { target } => Instruction::branch(target),
+ BlockExit::BreakIf {
+ condition,
+ preamble_id,
+ } => {
+ let condition_id = self.cached[condition];
+
+ Instruction::branch_conditional(
+ condition_id,
+ loop_context.break_id.unwrap(),
+ preamble_id,
+ )
+ }
+ };
+
+ self.function.consume(block, termination);
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs
new file mode 100644
index 0000000000..1ef0db1912
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/helpers.rs
@@ -0,0 +1,108 @@
+use crate::{Handle, UniqueArena};
+use spirv::Word;
+
+pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> {
+ bytes
+ .chunks(4)
+ .map(|chars| chars.iter().rev().fold(0u32, |u, c| (u << 8) | *c as u32))
+ .collect()
+}
+
+pub(super) fn string_to_words(input: &str) -> Vec<Word> {
+ let bytes = input.as_bytes();
+ let mut words = bytes_to_words(bytes);
+
+ if bytes.len() % 4 == 0 {
+ // nul-termination
+ words.push(0x0u32);
+ }
+
+ words
+}
+
+pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass {
+ match space {
+ crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant,
+ crate::AddressSpace::Function => spirv::StorageClass::Function,
+ crate::AddressSpace::Private => spirv::StorageClass::Private,
+ crate::AddressSpace::Storage { .. } => spirv::StorageClass::StorageBuffer,
+ crate::AddressSpace::Uniform => spirv::StorageClass::Uniform,
+ crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup,
+ crate::AddressSpace::PushConstant => spirv::StorageClass::PushConstant,
+ }
+}
+
+pub(super) fn contains_builtin(
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ arena: &UniqueArena<crate::Type>,
+ built_in: crate::BuiltIn,
+) -> bool {
+ if let Some(&crate::Binding::BuiltIn(bi)) = binding {
+ bi == built_in
+ } else if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner {
+ members
+ .iter()
+ .any(|member| contains_builtin(member.binding.as_ref(), member.ty, arena, built_in))
+ } else {
+ false // unreachable
+ }
+}
+
+impl crate::AddressSpace {
+ pub(super) const fn to_spirv_semantics_and_scope(
+ self,
+ ) -> (spirv::MemorySemantics, spirv::Scope) {
+ match self {
+ Self::Storage { .. } => (spirv::MemorySemantics::UNIFORM_MEMORY, spirv::Scope::Device),
+ Self::WorkGroup => (
+ spirv::MemorySemantics::WORKGROUP_MEMORY,
+ spirv::Scope::Workgroup,
+ ),
+ _ => (spirv::MemorySemantics::empty(), spirv::Scope::Invocation),
+ }
+ }
+}
+
+/// Return true if the global requires a type decorated with `Block`.
+///
+/// Vulkan spec v1.3 §15.6.2, "Descriptor Set Interface", says:
+///
+/// > Variables identified with the `Uniform` storage class are used to
+/// > access transparent buffer backed resources. Such variables must
+/// > be:
+/// >
+/// > - typed as `OpTypeStruct`, or an array of this type,
+/// >
+/// > - identified with a `Block` or `BufferBlock` decoration, and
+/// >
+/// > - laid out explicitly using the `Offset`, `ArrayStride`, and
+/// > `MatrixStride` decorations as specified in §15.6.4, "Offset
+/// > and Stride Assignment."
+// See `back::spv::GlobalVariable::access_id` for details.
+pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariable) -> bool {
+ match var.space {
+ crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::PushConstant => {}
+ _ => return false,
+ };
+ match ir_module.types[var.ty].inner {
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => match members.last() {
+ Some(member) => match ir_module.types[member.ty].inner {
+ // Structs with dynamically sized arrays can't be copied and can't be wrapped.
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => false,
+ _ => true,
+ },
+ None => false,
+ },
+ // if it's not a structure, let's wrap it to be able to put "Block"
+ _ => true,
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/image.rs b/third_party/rust/naga/src/back/spv/image.rs
new file mode 100644
index 0000000000..e070cd6175
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/image.rs
@@ -0,0 +1,1179 @@
+/*!
+Generating SPIR-V for image operations.
+*/
+
+use super::{
+ selection::{MergeTuple, Selection},
+ Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType,
+};
+use crate::arena::Handle;
+use spirv::Word;
+
+/// Information about a vector of coordinates.
+///
+/// The coordinate vectors expected by SPIR-V `OpImageRead` and `OpImageFetch`
+/// supply the array index for arrayed images as an additional component at
+/// the end, whereas Naga's `ImageLoad`, `ImageStore`, and `ImageSample` carry
+/// the array index as a separate field.
+///
+/// In the process of generating code to compute the combined vector, we also
+/// produce SPIR-V types and vector lengths that are useful elsewhere. This
+/// struct gathers that information into one place, with standard names.
+struct ImageCoordinates {
+ /// The SPIR-V id of the combined coordinate/index vector value.
+ ///
+ /// Note: when indexing a non-arrayed 1D image, this will be a scalar.
+ value_id: Word,
+
+ /// The SPIR-V id of the type of `value`.
+ type_id: Word,
+
+ /// The number of components in `value`, if it is a vector, or `None` if it
+ /// is a scalar.
+ size: Option<crate::VectorSize>,
+}
+
+/// A trait for image access (load or store) code generators.
+///
+/// Types implementing this trait hold information about an `ImageStore` or
+/// `ImageLoad` operation that is not affected by the bounds check policy. The
+/// `generate` method emits code for the access, given the results of bounds
+/// checking.
+///
+/// The [`image`] bounds checks policy affects access coordinates, level of
+/// detail, and sample index, but never the image id, result type (if any), or
+/// the specific SPIR-V instruction used. Types that implement this trait gather
+/// together the latter category, so we don't have to plumb them through the
+/// bounds-checking code.
+///
+/// [`image`]: crate::proc::BoundsCheckPolicies::index
+trait Access {
+ /// The Rust type that represents SPIR-V values and types for this access.
+ ///
+ /// For operations like loads, this is `Word`. For operations like stores,
+ /// this is `()`.
+ ///
+ /// For `ReadZeroSkipWrite`, this will be the type of the selection
+ /// construct that performs the bounds checks, so it must implement
+ /// `MergeTuple`.
+ type Output: MergeTuple + Copy + Clone;
+
+ /// Write an image access to `block`.
+ ///
+ /// Access the texel at `coordinates_id`. The optional `level_id` indicates
+ /// the level of detail, and `sample_id` is the index of the sample to
+ /// access in a multisampled texel.
+ ///
+ /// Ths method assumes that `coordinates_id` has already had the image array
+ /// index, if any, folded in, as done by `write_image_coordinates`.
+ ///
+ /// Return the value id produced by the instruction, if any.
+ ///
+ /// Use `id_gen` to generate SPIR-V ids as necessary.
+ fn generate(
+ &self,
+ id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Self::Output;
+
+ /// Return the SPIR-V type of the value produced by the code written by
+ /// `generate`. If the access does not produce a value, `Self::Output`
+ /// should be `()`.
+ fn result_type(&self) -> Self::Output;
+
+ /// Construct the SPIR-V 'zero' value to be returned for an out-of-bounds
+ /// access under the `ReadZeroSkipWrite` policy. If the access does not
+ /// produce a value, `Self::Output` should be `()`.
+ fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Self::Output;
+}
+
+/// Texel access information for an [`ImageLoad`] expression.
+///
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+struct Load {
+ /// The specific opcode we'll use to perform the fetch. Storage images
+ /// require `OpImageRead`, while sampled images require `OpImageFetch`.
+ opcode: spirv::Op,
+
+ /// The type id produced by the actual image access instruction.
+ type_id: Word,
+
+ /// The id of the image being accessed.
+ image_id: Word,
+}
+
+impl Load {
+ fn from_image_expr(
+ ctx: &mut BlockContext<'_>,
+ image_id: Word,
+ image_class: crate::ImageClass,
+ result_type_id: Word,
+ ) -> Result<Load, Error> {
+ let opcode = match image_class {
+ crate::ImageClass::Storage { .. } => spirv::Op::ImageRead,
+ crate::ImageClass::Depth { .. } | crate::ImageClass::Sampled { .. } => {
+ spirv::Op::ImageFetch
+ }
+ };
+
+ // `OpImageRead` and `OpImageFetch` instructions produce vec4<f32>
+ // values. Most of the time, we can just use `result_type_id` for
+ // this. The exception is that `Expression::ImageLoad` from a depth
+ // image produces a scalar `f32`, so in that case we need to find
+ // the right SPIR-V type for the access instruction here.
+ let type_id = match image_class {
+ crate::ImageClass::Depth { .. } => {
+ ctx.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ }))
+ }
+ _ => result_type_id,
+ };
+
+ Ok(Load {
+ opcode,
+ type_id,
+ image_id,
+ })
+ }
+}
+
+impl Access for Load {
+ type Output = Word;
+
+ /// Write an instruction to access a given texel of this image.
+ fn generate(
+ &self,
+ id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Word {
+ let texel_id = id_gen.next();
+ let mut instruction = Instruction::image_fetch_or_read(
+ self.opcode,
+ self.type_id,
+ texel_id,
+ self.image_id,
+ coordinates_id,
+ );
+
+ match (level_id, sample_id) {
+ (None, None) => {}
+ (Some(level_id), None) => {
+ instruction.add_operand(spirv::ImageOperands::LOD.bits());
+ instruction.add_operand(level_id);
+ }
+ (None, Some(sample_id)) => {
+ instruction.add_operand(spirv::ImageOperands::SAMPLE.bits());
+ instruction.add_operand(sample_id);
+ }
+ // There's no such thing as a multi-sampled mipmap.
+ (Some(_), Some(_)) => unreachable!(),
+ }
+
+ block.body.push(instruction);
+
+ texel_id
+ }
+
+ fn result_type(&self) -> Word {
+ self.type_id
+ }
+
+ fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Word {
+ ctx.writer.write_constant_null(self.type_id)
+ }
+}
+
+/// Texel access information for a [`Store`] statement.
+///
+/// [`Store`]: crate::Statement::Store
+struct Store {
+ /// The id of the image being written to.
+ image_id: Word,
+
+ /// The value we're going to write to the texel.
+ value_id: Word,
+}
+
+impl Access for Store {
+ /// Stores don't generate any value.
+ type Output = ();
+
+ fn generate(
+ &self,
+ _id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ _level_id: Option<Word>,
+ _sample_id: Option<Word>,
+ block: &mut Block,
+ ) {
+ block.body.push(Instruction::image_write(
+ self.image_id,
+ coordinates_id,
+ self.value_id,
+ ));
+ }
+
+ /// Stores don't generate any value, so this just returns `()`.
+ fn result_type(&self) {}
+
+ /// Stores don't generate any value, so this just returns `()`.
+ fn out_of_bounds_value(&self, _ctx: &mut BlockContext<'_>) {}
+}
+
+impl<'w> BlockContext<'w> {
+ /// Extend image coordinates with an array index, if necessary.
+ ///
+ /// Whereas [`Expression::ImageLoad`] and [`ImageSample`] treat the array
+ /// index as a separate operand from the coordinates, SPIR-V image access
+ /// instructions include the array index in the `coordinates` operand. This
+ /// function builds a SPIR-V coordinate vector from a Naga coordinate vector
+ /// and array index, if one is supplied, and returns a `ImageCoordinates`
+ /// struct describing what it built.
+ ///
+ /// If `array_index` is `Some(expr)`, then this function constructs a new
+ /// vector that is `coordinates` with `array_index` concatenated onto the
+ /// end: a `vec2` becomes a `vec3`, a scalar becomes a `vec2`, and so on.
+ ///
+ /// If `array_index` is `None`, then the return value uses `coordinates`
+ /// unchanged. Note that, when indexing a non-arrayed 1D image, this will be
+ /// a scalar value.
+ ///
+ /// If needed, this function generates code to convert the array index,
+ /// always an integer scalar, to match the component type of `coordinates`.
+ /// Naga's `ImageLoad` and SPIR-V's `OpImageRead`, `OpImageFetch`, and
+ /// `OpImageWrite` all use integer coordinates, while Naga's `ImageSample`
+ /// and SPIR-V's `OpImageSample...` instructions all take floating-point
+ /// coordinate vectors.
+ ///
+ /// [`Expression::ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageSample`]: crate::Expression::ImageSample
+ fn write_image_coordinates(
+ &mut self,
+ coordinates: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<ImageCoordinates, Error> {
+ use crate::TypeInner as Ti;
+ use crate::VectorSize as Vs;
+
+ let coordinates_id = self.cached[coordinates];
+ let ty = &self.fun_info[coordinates].ty;
+ let inner_ty = ty.inner_with(&self.ir_module.types);
+
+ // If there's no array index, the image coordinates are exactly the
+ // `coordinate` field of the `Expression::ImageLoad`. No work is needed.
+ let array_index = match array_index {
+ None => {
+ let value_id = coordinates_id;
+ let type_id = self.get_expression_type_id(ty);
+ let size = match *inner_ty {
+ Ti::Scalar { .. } => None,
+ Ti::Vector { size, .. } => Some(size),
+ _ => return Err(Error::Validation("coordinate type")),
+ };
+ return Ok(ImageCoordinates {
+ value_id,
+ type_id,
+ size,
+ });
+ }
+ Some(ix) => ix,
+ };
+
+ // Find the component type of `coordinates`, and figure out the size the
+ // combined coordinate vector will have.
+ let (component_kind, size) = match *inner_ty {
+ Ti::Scalar { kind, width: 4 } => (kind, Some(Vs::Bi)),
+ Ti::Vector {
+ kind,
+ width: 4,
+ size: Vs::Bi,
+ } => (kind, Some(Vs::Tri)),
+ Ti::Vector {
+ kind,
+ width: 4,
+ size: Vs::Tri,
+ } => (kind, Some(Vs::Quad)),
+ Ti::Vector { size: Vs::Quad, .. } => {
+ return Err(Error::Validation("extending vec4 coordinate"));
+ }
+ ref other => {
+ log::error!("wrong coordinate type {:?}", other);
+ return Err(Error::Validation("coordinate type"));
+ }
+ };
+
+ // Convert the index to the coordinate component type, if necessary.
+ let array_index_i32_id = self.cached[array_index];
+ let reconciled_array_index_id = if component_kind == crate::ScalarKind::Sint {
+ array_index_i32_id
+ } else {
+ let component_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: component_kind,
+ width: 4,
+ pointer_space: None,
+ }));
+
+ let reconciled_id = self.gen_id();
+ block.body.push(Instruction::unary(
+ spirv::Op::ConvertUToF,
+ component_type_id,
+ reconciled_id,
+ array_index_i32_id,
+ ));
+ reconciled_id
+ };
+
+ // Find the SPIR-V type for the combined coordinates/index vector.
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: size,
+ kind: component_kind,
+ width: 4,
+ pointer_space: None,
+ }));
+
+ // Schmear the coordinates and index together.
+ let value_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ type_id,
+ value_id,
+ &[coordinates_id, reconciled_array_index_id],
+ ));
+ Ok(ImageCoordinates {
+ value_id,
+ type_id,
+ size,
+ })
+ }
+
+ fn get_image_id(&mut self, expr_handle: Handle<crate::Expression>) -> Word {
+ let id = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::GlobalVariable(handle) => {
+ self.writer.global_variables[handle.index()].handle_id
+ }
+ crate::Expression::FunctionArgument(i) => {
+ self.function.parameters[i as usize].handle_id
+ }
+ crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => {
+ self.cached[expr_handle]
+ }
+ ref other => unreachable!("Unexpected image expression {:?}", other),
+ };
+
+ if id == 0 {
+ unreachable!(
+ "Image expression {:?} doesn't have a handle ID",
+ expr_handle
+ );
+ }
+
+ id
+ }
+
+ /// Generate a vector or scalar 'one' for arithmetic on `coordinates`.
+ ///
+ /// If `coordinates` is a scalar, return a scalar one. Otherwise, return
+ /// a vector of ones.
+ fn write_coordinate_one(&mut self, coordinates: &ImageCoordinates) -> Result<Word, Error> {
+ let one = self.get_scope_constant(1);
+ match coordinates.size {
+ None => Ok(one),
+ Some(vector_size) => {
+ let ones = [one; 4];
+ let id = self.gen_id();
+ Instruction::constant_composite(
+ coordinates.type_id,
+ id,
+ &ones[..vector_size as usize],
+ )
+ .to_words(&mut self.writer.logical_layout.declarations);
+ Ok(id)
+ }
+ }
+ }
+
+ /// Generate code to restrict `input` to fall between zero and one less than
+ /// `size_id`.
+ ///
+ /// Both must be 32-bit scalar integer values, whose type is given by
+ /// `type_id`. The computed value is also of type `type_id`.
+ fn restrict_scalar(
+ &mut self,
+ type_id: Word,
+ input_id: Word,
+ size_id: Word,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let i32_one_id = self.get_scope_constant(1);
+
+ // Subtract one from `size` to get the largest valid value.
+ let limit_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ type_id,
+ limit_id,
+ size_id,
+ i32_one_id,
+ ));
+
+ // Use an unsigned minimum, to handle both positive out-of-range values
+ // and negative values in a single instruction: negative values of
+ // `input_id` get treated as very large positive values.
+ let restricted_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ type_id,
+ restricted_id,
+ &[input_id, limit_id],
+ ));
+
+ Ok(restricted_id)
+ }
+
+ /// Write instructions to query the size of an image.
+ ///
+ /// This takes care of selecting the right instruction depending on whether
+ /// a level of detail parameter is present.
+ fn write_coordinate_bounds(
+ &mut self,
+ type_id: Word,
+ image_id: Word,
+ level_id: Option<Word>,
+ block: &mut Block,
+ ) -> Word {
+ let coordinate_bounds_id = self.gen_id();
+ match level_id {
+ Some(level_id) => {
+ // A level of detail was provided, so fetch the image size for
+ // that level.
+ let mut inst = Instruction::image_query(
+ spirv::Op::ImageQuerySizeLod,
+ type_id,
+ coordinate_bounds_id,
+ image_id,
+ );
+ inst.add_operand(level_id);
+ block.body.push(inst);
+ }
+ _ => {
+ // No level of detail was given.
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySize,
+ type_id,
+ coordinate_bounds_id,
+ image_id,
+ ));
+ }
+ }
+
+ coordinate_bounds_id
+ }
+
+ /// Write code to restrict coordinates for an image reference.
+ ///
+ /// First, clamp the level of detail or sample index to fall within bounds.
+ /// Then, obtain the image size, possibly using the clamped level of detail.
+ /// Finally, use an unsigned minimum instruction to force all coordinates
+ /// into range.
+ ///
+ /// Return a triple `(COORDS, LEVEL, SAMPLE)`, where `COORDS` is a coordinate
+ /// vector (including the array index, if any), `LEVEL` is an optional level
+ /// of detail, and `SAMPLE` is an optional sample index, all guaranteed to
+ /// be in-bounds for `image_id`.
+ ///
+ /// The result is usually a vector, but it is a scalar when indexing
+ /// non-arrayed 1D images.
+ fn write_restricted_coordinates(
+ &mut self,
+ image_id: Word,
+ coordinates: ImageCoordinates,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Result<(Word, Option<Word>, Option<Word>), Error> {
+ self.writer.require_any(
+ "the `Restrict` image bounds check policy",
+ &[spirv::Capability::ImageQuery],
+ )?;
+
+ let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }));
+
+ // If `level` is `Some`, clamp it to fall within bounds. This must
+ // happen first, because we'll use it to query the image size for
+ // clamping the actual coordinates.
+ let level_id = level_id
+ .map(|level_id| {
+ // Find the number of mipmap levels in this image.
+ let num_levels_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ i32_type_id,
+ num_levels_id,
+ image_id,
+ ));
+
+ self.restrict_scalar(i32_type_id, level_id, num_levels_id, block)
+ })
+ .transpose()?;
+
+ // If `sample_id` is `Some`, clamp it to fall within bounds.
+ let sample_id = sample_id
+ .map(|sample_id| {
+ // Find the number of samples per texel.
+ let num_samples_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ i32_type_id,
+ num_samples_id,
+ image_id,
+ ));
+
+ self.restrict_scalar(i32_type_id, sample_id, num_samples_id, block)
+ })
+ .transpose()?;
+
+ // Obtain the image bounds, including the array element count.
+ let coordinate_bounds_id =
+ self.write_coordinate_bounds(coordinates.type_id, image_id, level_id, block);
+
+ // Compute maximum valid values from the bounds.
+ let ones = self.write_coordinate_one(&coordinates)?;
+ let coordinate_limit_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ coordinates.type_id,
+ coordinate_limit_id,
+ coordinate_bounds_id,
+ ones,
+ ));
+
+ // Restrict the coordinates to fall within those bounds.
+ //
+ // Use an unsigned minimum, to handle both positive out-of-range values
+ // and negative values in a single instruction: negative values of
+ // `coordinates` get treated as very large positive values.
+ let restricted_coordinates_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ coordinates.type_id,
+ restricted_coordinates_id,
+ &[coordinates.value_id, coordinate_limit_id],
+ ));
+
+ Ok((restricted_coordinates_id, level_id, sample_id))
+ }
+
+ fn write_conditional_image_access<A: Access>(
+ &mut self,
+ image_id: Word,
+ coordinates: ImageCoordinates,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ access: &A,
+ ) -> Result<A::Output, Error> {
+ self.writer.require_any(
+ "the `ReadZeroSkipWrite` image bounds check policy",
+ &[spirv::Capability::ImageQuery],
+ )?;
+
+ let bool_type_id = self.writer.get_bool_type_id();
+ let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }));
+
+ let null_id = access.out_of_bounds_value(self);
+
+ let mut selection = Selection::start(block, access.result_type());
+
+ // If `level_id` is `Some`, check whether it is within bounds. This must
+ // happen first, because we'll be supplying this as an argument when we
+ // query the image size.
+ if let Some(level_id) = level_id {
+ // Find the number of mipmap levels in this image.
+ let num_levels_id = self.gen_id();
+ selection.block().body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ i32_type_id,
+ num_levels_id,
+ image_id,
+ ));
+
+ let lod_cond_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ bool_type_id,
+ lod_cond_id,
+ level_id,
+ num_levels_id,
+ ));
+
+ selection.if_true(self, lod_cond_id, null_id);
+ }
+
+ // If `sample_id` is `Some`, check whether it is in bounds.
+ if let Some(sample_id) = sample_id {
+ // Find the number of samples per texel.
+ let num_samples_id = self.gen_id();
+ selection.block().body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ i32_type_id,
+ num_samples_id,
+ image_id,
+ ));
+
+ let samples_cond_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ bool_type_id,
+ samples_cond_id,
+ sample_id,
+ num_samples_id,
+ ));
+
+ selection.if_true(self, samples_cond_id, null_id);
+ }
+
+ // Obtain the image bounds, including any array element count.
+ let coordinate_bounds_id = self.write_coordinate_bounds(
+ coordinates.type_id,
+ image_id,
+ level_id,
+ selection.block(),
+ );
+
+ // Compare the coordinates against the bounds.
+ let coords_bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: coordinates.size,
+ kind: crate::ScalarKind::Bool,
+ width: 1,
+ pointer_space: None,
+ }));
+ let coords_conds_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ coords_bool_type_id,
+ coords_conds_id,
+ coordinates.value_id,
+ coordinate_bounds_id,
+ ));
+
+ // If the comparison above was a vector comparison, then we need to
+ // check that all components of the comparison are true.
+ let coords_cond_id = if coords_bool_type_id != bool_type_id {
+ let id = self.gen_id();
+ selection.block().body.push(Instruction::relational(
+ spirv::Op::All,
+ bool_type_id,
+ id,
+ coords_conds_id,
+ ));
+ id
+ } else {
+ coords_conds_id
+ };
+
+ selection.if_true(self, coords_cond_id, null_id);
+
+ // All conditions are met. We can carry out the access.
+ let texel_id = access.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ level_id,
+ sample_id,
+ selection.block(),
+ );
+
+ // This, then, is the value of the 'true' branch.
+ Ok(selection.finish(self, texel_id))
+ }
+
+ /// Generate code for an `ImageLoad` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageLoad` variant.
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn write_image_load(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ level: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let image_id = self.get_image_id(image);
+ let image_type = self.fun_info[image].ty.inner_with(&self.ir_module.types);
+ let image_class = match *image_type {
+ crate::TypeInner::Image { class, .. } => class,
+ _ => return Err(Error::Validation("image type")),
+ };
+
+ let access = Load::from_image_expr(self, image_id, image_class, result_type_id)?;
+ let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
+
+ let level_id = level.map(|expr| self.cached[expr]);
+ let sample_id = sample.map(|expr| self.cached[expr]);
+
+ // Perform the access, according to the bounds check policy.
+ let access_id = match self.writer.bounds_check_policies.image {
+ crate::proc::BoundsCheckPolicy::Restrict => {
+ let (coords, level_id, sample_id) = self.write_restricted_coordinates(
+ image_id,
+ coordinates,
+ level_id,
+ sample_id,
+ block,
+ )?;
+ access.generate(&mut self.writer.id_gen, coords, level_id, sample_id, block)
+ }
+ crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => self
+ .write_conditional_image_access(
+ image_id,
+ coordinates,
+ level_id,
+ sample_id,
+ block,
+ &access,
+ )?,
+ crate::proc::BoundsCheckPolicy::Unchecked => access.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ level_id,
+ sample_id,
+ block,
+ ),
+ };
+
+ // For depth images, `ImageLoad` expressions produce a single f32,
+ // whereas the SPIR-V instructions always produce a vec4. So we may have
+ // to pull out the component we need.
+ let result_id = if result_type_id == access.result_type() {
+ // The instruction produced the type we expected. We can use
+ // its result as-is.
+ access_id
+ } else {
+ // For `ImageClass::Depth` images, SPIR-V gave us four components,
+ // but we only want the first one.
+ let component_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ component_id,
+ access_id,
+ &[0],
+ ));
+ component_id
+ };
+
+ Ok(result_id)
+ }
+
+ /// Generate code for an `ImageSample` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageSample` variant.
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn write_image_sample(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ sampler: Handle<crate::Expression>,
+ gather: Option<crate::SwizzleComponent>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ offset: Option<Handle<crate::Constant>>,
+ level: crate::SampleLevel,
+ depth_ref: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ use super::instructions::SampleLod;
+ // image
+ let image_id = self.get_image_id(image);
+ let image_type = self.fun_info[image].ty.handle().unwrap();
+ // SPIR-V doesn't know about our `Depth` class, and it returns
+ // `vec4<f32>`, so we need to grab the first component out of it.
+ let needs_sub_access = match self.ir_module.types[image_type].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Depth { .. },
+ ..
+ } => depth_ref.is_none() && gather.is_none(),
+ _ => false,
+ };
+ let sample_result_type_id = if needs_sub_access {
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ }))
+ } else {
+ result_type_id
+ };
+
+ // OpTypeSampledImage
+ let image_type_id = self.get_type_id(LookupType::Handle(image_type));
+ let sampled_image_type_id =
+ self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id }));
+
+ let sampler_id = self.get_image_id(sampler);
+ let coordinates_id = self
+ .write_image_coordinates(coordinate, array_index, block)?
+ .value_id;
+
+ let sampled_image_id = self.gen_id();
+ block.body.push(Instruction::sampled_image(
+ sampled_image_type_id,
+ sampled_image_id,
+ image_id,
+ sampler_id,
+ ));
+ let id = self.gen_id();
+
+ let depth_id = depth_ref.map(|handle| self.cached[handle]);
+ let mut mask = spirv::ImageOperands::empty();
+ mask.set(spirv::ImageOperands::CONST_OFFSET, offset.is_some());
+
+ let mut main_instruction = match (level, gather) {
+ (_, Some(component)) => {
+ let component_id = self.get_index_constant(component as u32);
+ let mut inst = Instruction::image_gather(
+ sample_result_type_id,
+ id,
+ sampled_image_id,
+ coordinates_id,
+ component_id,
+ depth_id,
+ );
+ if !mask.is_empty() {
+ inst.add_operand(mask.bits());
+ }
+ inst
+ }
+ (crate::SampleLevel::Zero, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let zero_id = self
+ .writer
+ .get_constant_scalar(crate::ScalarValue::Float(0.0), 4);
+
+ mask |= spirv::ImageOperands::LOD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(zero_id);
+
+ inst
+ }
+ (crate::SampleLevel::Auto, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Implicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+ if !mask.is_empty() {
+ inst.add_operand(mask.bits());
+ }
+ inst
+ }
+ (crate::SampleLevel::Exact(lod_handle), None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let lod_id = self.cached[lod_handle];
+ mask |= spirv::ImageOperands::LOD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(lod_id);
+
+ inst
+ }
+ (crate::SampleLevel::Bias(bias_handle), None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Implicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let bias_id = self.cached[bias_handle];
+ mask |= spirv::ImageOperands::BIAS;
+ inst.add_operand(mask.bits());
+ inst.add_operand(bias_id);
+
+ inst
+ }
+ (crate::SampleLevel::Gradient { x, y }, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let x_id = self.cached[x];
+ let y_id = self.cached[y];
+ mask |= spirv::ImageOperands::GRAD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(x_id);
+ inst.add_operand(y_id);
+
+ inst
+ }
+ };
+
+ if let Some(offset_const) = offset {
+ let offset_id = self.writer.constant_ids[offset_const.index()];
+ main_instruction.add_operand(offset_id);
+ }
+
+ block.body.push(main_instruction);
+
+ let id = if needs_sub_access {
+ let sub_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ sub_id,
+ id,
+ &[0],
+ ));
+ sub_id
+ } else {
+ id
+ };
+
+ Ok(id)
+ }
+
+ /// Generate code for an `ImageQuery` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageQuery` variant.
+ pub(super) fn write_image_query(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ query: crate::ImageQuery,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ use crate::{ImageClass as Ic, ImageDimension as Id, ImageQuery as Iq};
+
+ let image_id = self.get_image_id(image);
+ let image_type = self.fun_info[image].ty.handle().unwrap();
+ let (dim, arrayed, class) = match self.ir_module.types[image_type].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => (dim, arrayed, class),
+ _ => {
+ return Err(Error::Validation("image type"));
+ }
+ };
+
+ self.writer
+ .require_any("image queries", &[spirv::Capability::ImageQuery])?;
+
+ let id = match query {
+ Iq::Size { level } => {
+ let dim_coords = match dim {
+ Id::D1 => 1,
+ Id::D2 | Id::Cube => 2,
+ Id::D3 => 3,
+ };
+ let extended_size_type_id = {
+ let array_coords = if arrayed { 1 } else { 0 };
+ let vector_size = match dim_coords + array_coords {
+ 2 => Some(crate::VectorSize::Bi),
+ 3 => Some(crate::VectorSize::Tri),
+ 4 => Some(crate::VectorSize::Quad),
+ _ => None,
+ };
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }))
+ };
+
+ let (query_op, level_id) = match class {
+ Ic::Sampled { multi: true, .. }
+ | Ic::Depth { multi: true }
+ | Ic::Storage { .. } => (spirv::Op::ImageQuerySize, None),
+ _ => {
+ let level_id = match level {
+ Some(expr) => self.cached[expr],
+ None => self.get_index_constant(0),
+ };
+ (spirv::Op::ImageQuerySizeLod, Some(level_id))
+ }
+ };
+
+ // The ID of the vector returned by SPIR-V, which contains the dimensions
+ // as well as the layer count.
+ let id_extended = self.gen_id();
+ let mut inst = Instruction::image_query(
+ query_op,
+ extended_size_type_id,
+ id_extended,
+ image_id,
+ );
+ if let Some(expr_id) = level_id {
+ inst.add_operand(expr_id);
+ }
+ block.body.push(inst);
+
+ if result_type_id != extended_size_type_id {
+ let id = self.gen_id();
+ let components = match dim {
+ // always pick the first component, and duplicate it for all 3 dimensions
+ Id::Cube => &[0u32, 0][..],
+ _ => &[0u32, 1, 2, 3][..dim_coords],
+ };
+ block.body.push(Instruction::vector_shuffle(
+ result_type_id,
+ id,
+ id_extended,
+ id_extended,
+ components,
+ ));
+ id
+ } else {
+ id_extended
+ }
+ }
+ Iq::NumLevels => {
+ let id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ result_type_id,
+ id,
+ image_id,
+ ));
+ id
+ }
+ Iq::NumLayers => {
+ let vec_size = match dim {
+ Id::D1 => crate::VectorSize::Bi,
+ Id::D2 | Id::Cube => crate::VectorSize::Tri,
+ Id::D3 => crate::VectorSize::Quad,
+ };
+ let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(vec_size),
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }));
+ let id_extended = self.gen_id();
+ let mut inst = Instruction::image_query(
+ spirv::Op::ImageQuerySizeLod,
+ extended_size_type_id,
+ id_extended,
+ image_id,
+ );
+ inst.add_operand(self.get_index_constant(0));
+ block.body.push(inst);
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ id,
+ id_extended,
+ &[vec_size as u32 - 1],
+ ));
+ id
+ }
+ Iq::NumSamples => {
+ let id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ result_type_id,
+ id,
+ image_id,
+ ));
+ id
+ }
+ };
+
+ Ok(id)
+ }
+
+ pub(super) fn write_image_store(
+ &mut self,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ value: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ let image_id = self.get_image_id(image);
+ let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
+ let value_id = self.cached[value];
+
+ let write = Store { image_id, value_id };
+
+ match self.writer.bounds_check_policies.image {
+ crate::proc::BoundsCheckPolicy::Restrict => {
+ let (coords, _, _) =
+ self.write_restricted_coordinates(image_id, coordinates, None, None, block)?;
+ write.generate(&mut self.writer.id_gen, coords, None, None, block);
+ }
+ crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ self.write_conditional_image_access(
+ image_id,
+ coordinates,
+ None,
+ None,
+ block,
+ &write,
+ )?;
+ }
+ crate::proc::BoundsCheckPolicy::Unchecked => {
+ write.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ None,
+ None,
+ block,
+ );
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/index.rs b/third_party/rust/naga/src/back/spv/index.rs
new file mode 100644
index 0000000000..d2cbdf4d6d
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/index.rs
@@ -0,0 +1,417 @@
+/*!
+Bounds-checking for SPIR-V output.
+*/
+
+use super::{
+ helpers::global_needs_wrapper, selection::Selection, Block, BlockContext, Error, IdGenerator,
+ Instruction, Word,
+};
+use crate::{arena::Handle, proc::BoundsCheckPolicy};
+
+/// The results of performing a bounds check.
+///
+/// On success, `write_bounds_check` returns a value of this type.
+pub(super) enum BoundsCheckResult {
+ /// The index is statically known and in bounds, with the given value.
+ KnownInBounds(u32),
+
+ /// The given instruction computes the index to be used.
+ Computed(Word),
+
+ /// The given instruction computes a boolean condition which is true
+ /// if the index is in bounds.
+ Conditional(Word),
+}
+
+/// A value that we either know at translation time, or need to compute at runtime.
+pub(super) enum MaybeKnown<T> {
+ /// The value is known at shader translation time.
+ Known(T),
+
+ /// The value is computed by the instruction with the given id.
+ Computed(Word),
+}
+
+impl<'w> BlockContext<'w> {
+ /// Emit code to compute the length of a run-time array.
+ ///
+ /// Given `array`, an expression referring a runtime-sized array, return the
+ /// instruction id for the array's length.
+ pub(super) fn write_runtime_array_length(
+ &mut self,
+ array: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ // Naga IR permits runtime-sized arrays as global variables or as the
+ // final member of a struct that is a global variable. SPIR-V permits
+ // only the latter, so this back end wraps bare runtime-sized arrays
+ // in a made-up struct; see `helpers::global_needs_wrapper` and its uses.
+ // This code must handle both cases.
+ let (structure_id, last_member_index) = match self.ir_function.expressions[array] {
+ crate::Expression::AccessIndex { base, index } => {
+ match self.ir_function.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => (
+ self.writer.global_variables[handle.index()].access_id,
+ index,
+ ),
+ _ => return Err(Error::Validation("array length expression")),
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let global = &self.ir_module.global_variables[handle];
+ if !global_needs_wrapper(self.ir_module, global) {
+ return Err(Error::Validation("array length expression"));
+ }
+
+ (self.writer.global_variables[handle.index()].var_id, 0)
+ }
+ _ => return Err(Error::Validation("array length expression")),
+ };
+
+ let length_id = self.gen_id();
+ block.body.push(Instruction::array_length(
+ self.writer.get_uint_type_id(),
+ length_id,
+ structure_id,
+ last_member_index,
+ ));
+
+ Ok(length_id)
+ }
+
+ /// Compute the length of a subscriptable value.
+ ///
+ /// Given `sequence`, an expression referring to some indexable type, return
+ /// its length. The result may either be computed by SPIR-V instructions, or
+ /// known at shader translation time.
+ ///
+ /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
+ /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
+ /// sized, or use a specializable constant as its length.
+ fn write_sequence_length(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<MaybeKnown<u32>, Error> {
+ let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
+ match sequence_ty.indexable_length(self.ir_module) {
+ Ok(crate::proc::IndexableLength::Known(known_length)) => {
+ Ok(MaybeKnown::Known(known_length))
+ }
+ Ok(crate::proc::IndexableLength::Dynamic) => {
+ let length_id = self.write_runtime_array_length(sequence, block)?;
+ Ok(MaybeKnown::Computed(length_id))
+ }
+ Err(err) => {
+ log::error!("Sequence length for {:?} failed: {}", sequence, err);
+ Err(Error::Validation("indexable length"))
+ }
+ }
+ }
+
+ /// Compute the maximum valid index of a subscriptable value.
+ ///
+ /// Given `sequence`, an expression referring to some indexable type, return
+ /// its maximum valid index - one less than its length. The result may
+ /// either be computed, or known at shader translation time.
+ ///
+ /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
+ /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
+ /// sized, or use a specializable constant as its length.
+ fn write_sequence_max_index(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<MaybeKnown<u32>, Error> {
+ match self.write_sequence_length(sequence, block)? {
+ MaybeKnown::Known(known_length) => {
+ // We should have thrown out all attempts to subscript zero-length
+ // sequences during validation, so the following subtraction should never
+ // underflow.
+ assert!(known_length > 0);
+ // Compute the max index from the length now.
+ Ok(MaybeKnown::Known(known_length - 1))
+ }
+ MaybeKnown::Computed(length_id) => {
+ // Emit code to compute the max index from the length.
+ let const_one_id = self.get_index_constant(1);
+ let max_index_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ self.writer.get_uint_type_id(),
+ max_index_id,
+ length_id,
+ const_one_id,
+ ));
+ Ok(MaybeKnown::Computed(max_index_id))
+ }
+ }
+ }
+
+ /// Restrict an index to be in range for a vector, matrix, or array.
+ ///
+ /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds
+ /// index is left unchanged. An out-of-bounds index is replaced with some
+ /// arbitrary in-bounds index. Note,this is not necessarily clamping; for
+ /// example, negative indices might be changed to refer to the last element
+ /// of the sequence, not the first, as clamping would do.
+ ///
+ /// Either return the restricted index value, if known, or add instructions
+ /// to `block` to compute it, and return the id of the result. See the
+ /// documentation for `BoundsCheckResult` for details.
+ ///
+ /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
+ /// `Pointer` to any of those, or a `ValuePointer`. An array may be
+ /// fixed-size, dynamically sized, or use a specializable constant as its
+ /// length.
+ pub(super) fn write_restricted_index(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let index_id = self.cached[index];
+
+ // Get the sequence's maximum valid index. Return early if we've already
+ // done the bounds check.
+ let max_index_id = match self.write_sequence_max_index(sequence, block)? {
+ MaybeKnown::Known(known_max_index) => {
+ if let crate::Expression::Constant(index_k) = self.ir_function.expressions[index] {
+ if let Some(known_index) = self.ir_module.constants[index_k].to_array_length() {
+ // Both the index and length are known at compile time.
+ //
+ // In strict WGSL compliance mode, out-of-bounds indices cannot be
+ // reported at shader translation time, and must be replaced with
+ // in-bounds indices at run time. So we cannot assume that
+ // validation ensured the index was in bounds. Restrict now.
+ let restricted = std::cmp::min(known_index, known_max_index);
+ return Ok(BoundsCheckResult::KnownInBounds(restricted));
+ }
+ }
+
+ self.get_index_constant(known_max_index)
+ }
+ MaybeKnown::Computed(max_index_id) => max_index_id,
+ };
+
+ // One or the other of the index or length is dynamic, so emit code for
+ // BoundsCheckPolicy::Restrict.
+ let restricted_index_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ self.writer.get_uint_type_id(),
+ restricted_index_id,
+ &[index_id, max_index_id],
+ ));
+ Ok(BoundsCheckResult::Computed(restricted_index_id))
+ }
+
+ /// Write an index bounds comparison to `block`, if needed.
+ ///
+ /// If we're able to determine statically that `index` is in bounds for
+ /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual
+ /// value of the index. (In principle, one could know that the index is in
+ /// bounds without knowing its specific value, but in our simple-minded
+ /// situation, we always know it.)
+ ///
+ /// If instead we must generate code to perform the comparison at run time,
+ /// return `Conditional(comparison_id)`, where `comparison_id` is an
+ /// instruction producing a boolean value that is true if `index` is in
+ /// bounds for `sequence`.
+ ///
+ /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
+ /// `Pointer` to any of those, or a `ValuePointer`. An array may be
+ /// fixed-size, dynamically sized, or use a specializable constant as its
+ /// length.
+ fn write_index_comparison(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let index_id = self.cached[index];
+
+ // Get the sequence's length. Return early if we've already done the
+ // bounds check.
+ let length_id = match self.write_sequence_length(sequence, block)? {
+ MaybeKnown::Known(known_length) => {
+ if let crate::Expression::Constant(index_k) = self.ir_function.expressions[index] {
+ if let Some(known_index) = self.ir_module.constants[index_k].to_array_length() {
+ // Both the index and length are known at compile time.
+ //
+ // It would be nice to assume that, since we are using the
+ // `ReadZeroSkipWrite` policy, we are not in strict WGSL
+ // compliance mode, and thus we can count on the validator to have
+ // rejected any programs with known out-of-bounds indices, and
+ // thus just return `KnownInBounds` here without actually
+ // checking.
+ //
+ // But it's also reasonable to expect that bounds check policies
+ // and error reporting policies should be able to vary
+ // independently without introducing security holes. So, we should
+ // support the case where bad indices do not cause validation
+ // errors, and are handled via `ReadZeroSkipWrite`.
+ //
+ // In theory, when `known_index` is bad, we could return a new
+ // `KnownOutOfBounds` variant here. But it's simpler just to fall
+ // through and let the bounds check take place. The shader is
+ // broken anyway, so it doesn't make sense to invest in emitting
+ // the ideal code for it.
+ if known_index < known_length {
+ return Ok(BoundsCheckResult::KnownInBounds(known_index));
+ }
+ }
+ }
+
+ self.get_index_constant(known_length)
+ }
+ MaybeKnown::Computed(length_id) => length_id,
+ };
+
+ // Compare the index against the length.
+ let condition_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ self.writer.get_bool_type_id(),
+ condition_id,
+ index_id,
+ length_id,
+ ));
+
+ // Indicate that we did generate the check.
+ Ok(BoundsCheckResult::Conditional(condition_id))
+ }
+
+ /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`.
+ ///
+ /// Generate code to load a value of `result_type` if `condition` is true,
+ /// and generate a null value of that type if it is false. Call `emit_load`
+ /// to emit the instructions to perform the load. Return the id of the
+ /// merged value of the two branches.
+ pub(super) fn write_conditional_indexed_load<F>(
+ &mut self,
+ result_type: Word,
+ condition: Word,
+ block: &mut Block,
+ emit_load: F,
+ ) -> Word
+ where
+ F: FnOnce(&mut IdGenerator, &mut Block) -> Word,
+ {
+ // For the out-of-bounds case, we produce a zero value.
+ let null_id = self.writer.write_constant_null(result_type);
+
+ let mut selection = Selection::start(block, result_type);
+
+ // As it turns out, we don't actually need a full 'if-then-else'
+ // structure for this: SPIR-V constants are declared up front, so the
+ // 'else' block would have no instructions. Instead we emit something
+ // like this:
+ //
+ // result = zero;
+ // if in_bounds {
+ // result = do the load;
+ // }
+ // use result;
+
+ // Continue only if the index was in bounds. Otherwise, branch to the
+ // merge block.
+ selection.if_true(self, condition, null_id);
+
+ // The in-bounds path. Perform the access and the load.
+ let loaded_value = emit_load(&mut self.writer.id_gen, selection.block());
+
+ selection.finish(self, loaded_value)
+ }
+
+ /// Emit code for bounds checks for an array, vector, or matrix access.
+ ///
+ /// This implements either `index_bounds_check_policy` or
+ /// `buffer_bounds_check_policy`, depending on the address space of the
+ /// pointer being accessed.
+ ///
+ /// Return a `BoundsCheckResult` indicating how the index should be
+ /// consumed. See that type's documentation for details.
+ pub(super) fn write_bounds_check(
+ &mut self,
+ base: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let policy = self.writer.bounds_check_policies.choose_policy(
+ base,
+ &self.ir_module.types,
+ self.fun_info,
+ );
+
+ Ok(match policy {
+ BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?,
+ BoundsCheckPolicy::ReadZeroSkipWrite => {
+ self.write_index_comparison(base, index, block)?
+ }
+ BoundsCheckPolicy::Unchecked => BoundsCheckResult::Computed(self.cached[index]),
+ })
+ }
+
+ /// Emit code to subscript a vector by value with a computed index.
+ ///
+ /// Return the id of the element value.
+ pub(super) fn write_vector_access(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ base: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
+
+ let base_id = self.cached[base];
+ let index_id = self.cached[index];
+
+ let result_id = match self.write_bounds_check(base, index, block)? {
+ BoundsCheckResult::KnownInBounds(known_index) => {
+ let result_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ result_id,
+ base_id,
+ &[known_index],
+ ));
+ result_id
+ }
+ BoundsCheckResult::Computed(computed_index_id) => {
+ let result_id = self.gen_id();
+ block.body.push(Instruction::vector_extract_dynamic(
+ result_type_id,
+ result_id,
+ base_id,
+ computed_index_id,
+ ));
+ result_id
+ }
+ BoundsCheckResult::Conditional(comparison_id) => {
+ // Run-time bounds checks were required. Emit
+ // conditional load.
+ self.write_conditional_indexed_load(
+ result_type_id,
+ comparison_id,
+ block,
+ |id_gen, block| {
+ // The in-bounds path. Generate the access.
+ let element_id = id_gen.next();
+ block.body.push(Instruction::vector_extract_dynamic(
+ result_type_id,
+ element_id,
+ base_id,
+ index_id,
+ ));
+ element_id
+ },
+ )
+ }
+ };
+
+ Ok(result_id)
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs
new file mode 100644
index 0000000000..9ec1deb0b2
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/instructions.rs
@@ -0,0 +1,996 @@
+use super::helpers;
+use spirv::{Op, Word};
+
+pub(super) enum Signedness {
+ Unsigned = 0,
+ Signed = 1,
+}
+
+pub(super) enum SampleLod {
+ Explicit,
+ Implicit,
+}
+
+pub(super) struct Case {
+ pub value: Word,
+ pub label_id: Word,
+}
+
+impl super::Instruction {
+ //
+ // Debug Instructions
+ //
+
+ pub(super) fn source(source_language: spirv::SourceLanguage, version: u32) -> Self {
+ let mut instruction = Self::new(Op::Source);
+ instruction.add_operand(source_language as u32);
+ instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes()));
+ instruction
+ }
+
+ pub(super) fn name(target_id: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::Name);
+ instruction.add_operand(target_id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn member_name(target_id: Word, member: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::MemberName);
+ instruction.add_operand(target_id);
+ instruction.add_operand(member);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ //
+ // Annotation Instructions
+ //
+
+ pub(super) fn decorate(
+ target_id: Word,
+ decoration: spirv::Decoration,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::Decorate);
+ instruction.add_operand(target_id);
+ instruction.add_operand(decoration as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ pub(super) fn member_decorate(
+ target_id: Word,
+ member_index: Word,
+ decoration: spirv::Decoration,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::MemberDecorate);
+ instruction.add_operand(target_id);
+ instruction.add_operand(member_index);
+ instruction.add_operand(decoration as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ //
+ // Extension Instructions
+ //
+
+ pub(super) fn extension(name: &str) -> Self {
+ let mut instruction = Self::new(Op::Extension);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn ext_inst_import(id: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::ExtInstImport);
+ instruction.set_result(id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn ext_inst(
+ set_id: Word,
+ op: spirv::GLOp,
+ result_type_id: Word,
+ id: Word,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ExtInst);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(set_id);
+ instruction.add_operand(op as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ //
+ // Mode-Setting Instructions
+ //
+
+ pub(super) fn memory_model(
+ addressing_model: spirv::AddressingModel,
+ memory_model: spirv::MemoryModel,
+ ) -> Self {
+ let mut instruction = Self::new(Op::MemoryModel);
+ instruction.add_operand(addressing_model as u32);
+ instruction.add_operand(memory_model as u32);
+ instruction
+ }
+
+ pub(super) fn entry_point(
+ execution_model: spirv::ExecutionModel,
+ entry_point_id: Word,
+ name: &str,
+ interface_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::EntryPoint);
+ instruction.add_operand(execution_model as u32);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operands(helpers::string_to_words(name));
+
+ for interface_id in interface_ids {
+ instruction.add_operand(*interface_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn execution_mode(
+ entry_point_id: Word,
+ execution_mode: spirv::ExecutionMode,
+ args: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ExecutionMode);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operand(execution_mode as u32);
+ for arg in args {
+ instruction.add_operand(*arg);
+ }
+ instruction
+ }
+
+ pub(super) fn capability(capability: spirv::Capability) -> Self {
+ let mut instruction = Self::new(Op::Capability);
+ instruction.add_operand(capability as u32);
+ instruction
+ }
+
+ //
+ // Type-Declaration Instructions
+ //
+
+ pub(super) fn type_void(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeVoid);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_bool(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeBool);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_int(id: Word, width: Word, signedness: Signedness) -> Self {
+ let mut instruction = Self::new(Op::TypeInt);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction.add_operand(signedness as u32);
+ instruction
+ }
+
+ pub(super) fn type_float(id: Word, width: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeFloat);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction
+ }
+
+ pub(super) fn type_vector(
+ id: Word,
+ component_type_id: Word,
+ component_count: crate::VectorSize,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeVector);
+ instruction.set_result(id);
+ instruction.add_operand(component_type_id);
+ instruction.add_operand(component_count as u32);
+ instruction
+ }
+
+ pub(super) fn type_matrix(
+ id: Word,
+ column_type_id: Word,
+ column_count: crate::VectorSize,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeMatrix);
+ instruction.set_result(id);
+ instruction.add_operand(column_type_id);
+ instruction.add_operand(column_count as u32);
+ instruction
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn type_image(
+ id: Word,
+ sampled_type_id: Word,
+ dim: spirv::Dim,
+ flags: super::ImageTypeFlags,
+ image_format: spirv::ImageFormat,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeImage);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_type_id);
+ instruction.add_operand(dim as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::DEPTH) as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::ARRAYED) as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::MULTISAMPLED) as u32);
+ instruction.add_operand(if flags.contains(super::ImageTypeFlags::SAMPLED) {
+ 1
+ } else {
+ 2
+ });
+ instruction.add_operand(image_format as u32);
+ instruction
+ }
+
+ pub(super) fn type_sampler(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeSampler);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeSampledImage);
+ instruction.set_result(id);
+ instruction.add_operand(image_type_id);
+ instruction
+ }
+
+ pub(super) fn type_array(id: Word, element_type_id: Word, length_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction.add_operand(length_id);
+ instruction
+ }
+
+ pub(super) fn type_runtime_array(id: Word, element_type_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeRuntimeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction
+ }
+
+ pub(super) fn type_struct(id: Word, member_ids: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::TypeStruct);
+ instruction.set_result(id);
+
+ for member_id in member_ids {
+ instruction.add_operand(*member_id)
+ }
+
+ instruction
+ }
+
+ pub(super) fn type_pointer(
+ id: Word,
+ storage_class: spirv::StorageClass,
+ type_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypePointer);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+ instruction.add_operand(type_id);
+ instruction
+ }
+
+ pub(super) fn type_function(id: Word, return_type_id: Word, parameter_ids: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::TypeFunction);
+ instruction.set_result(id);
+ instruction.add_operand(return_type_id);
+
+ for parameter_id in parameter_ids {
+ instruction.add_operand(*parameter_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Constant-Creation Instructions
+ //
+
+ pub(super) fn constant_null(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantNull);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_true(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantTrue);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_false(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantFalse);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant(result_type_id: Word, id: Word, values: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::Constant);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for value in values {
+ instruction.add_operand(*value);
+ }
+
+ instruction
+ }
+
+ pub(super) fn constant_composite(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ConstantComposite);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Memory Instructions
+ //
+
+ pub(super) fn variable(
+ result_type_id: Word,
+ id: Word,
+ storage_class: spirv::StorageClass,
+ initializer_id: Option<Word>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Variable);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+
+ if let Some(initializer_id) = initializer_id {
+ instruction.add_operand(initializer_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn load(
+ result_type_id: Word,
+ id: Word,
+ pointer_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Load);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+ }
+
+ pub(super) fn atomic_load(
+ result_type_id: Word,
+ id: Word,
+ pointer_id: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::AtomicLoad);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction
+ }
+
+ pub(super) fn store(
+ pointer_id: Word,
+ value_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Store);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(value_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+ }
+
+ pub(super) fn atomic_store(
+ pointer_id: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ value_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::AtomicStore);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction.add_operand(value_id);
+ instruction
+ }
+
+ pub(super) fn access_chain(
+ result_type_id: Word,
+ id: Word,
+ base_id: Word,
+ index_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::AccessChain);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(base_id);
+
+ for index_id in index_ids {
+ instruction.add_operand(*index_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn array_length(
+ result_type_id: Word,
+ id: Word,
+ structure_id: Word,
+ array_member: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::ArrayLength);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(structure_id);
+ instruction.add_operand(array_member);
+ instruction
+ }
+
+ //
+ // Function Instructions
+ //
+
+ pub(super) fn function(
+ return_type_id: Word,
+ id: Word,
+ function_control: spirv::FunctionControl,
+ function_type_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Function);
+ instruction.set_type(return_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_control.bits());
+ instruction.add_operand(function_type_id);
+ instruction
+ }
+
+ pub(super) fn function_parameter(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::FunctionParameter);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) const fn function_end() -> Self {
+ Self::new(Op::FunctionEnd)
+ }
+
+ pub(super) fn function_call(
+ result_type_id: Word,
+ id: Word,
+ function_id: Word,
+ argument_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::FunctionCall);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_id);
+
+ for argument_id in argument_ids {
+ instruction.add_operand(*argument_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Image Instructions
+ //
+
+ pub(super) fn sampled_image(
+ result_type_id: Word,
+ id: Word,
+ image: Word,
+ sampler: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::SampledImage);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction.add_operand(sampler);
+ instruction
+ }
+
+ pub(super) fn image_sample(
+ result_type_id: Word,
+ id: Word,
+ lod: SampleLod,
+ sampled_image: Word,
+ coordinates: Word,
+ depth_ref: Option<Word>,
+ ) -> Self {
+ let op = match (lod, depth_ref) {
+ (SampleLod::Explicit, None) => Op::ImageSampleExplicitLod,
+ (SampleLod::Implicit, None) => Op::ImageSampleImplicitLod,
+ (SampleLod::Explicit, Some(_)) => Op::ImageSampleDrefExplicitLod,
+ (SampleLod::Implicit, Some(_)) => Op::ImageSampleDrefImplicitLod,
+ };
+
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_image);
+ instruction.add_operand(coordinates);
+ if let Some(dref) = depth_ref {
+ instruction.add_operand(dref);
+ }
+
+ instruction
+ }
+
+ pub(super) fn image_gather(
+ result_type_id: Word,
+ id: Word,
+ sampled_image: Word,
+ coordinates: Word,
+ component_id: Word,
+ depth_ref: Option<Word>,
+ ) -> Self {
+ let op = match depth_ref {
+ None => Op::ImageGather,
+ Some(_) => Op::ImageDrefGather,
+ };
+
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_image);
+ instruction.add_operand(coordinates);
+ if let Some(dref) = depth_ref {
+ instruction.add_operand(dref);
+ } else {
+ instruction.add_operand(component_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn image_fetch_or_read(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ image: Word,
+ coordinates: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction.add_operand(coordinates);
+ instruction
+ }
+
+ pub(super) fn image_write(image: Word, coordinates: Word, value: Word) -> Self {
+ let mut instruction = Self::new(Op::ImageWrite);
+ instruction.add_operand(image);
+ instruction.add_operand(coordinates);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ pub(super) fn image_query(op: Op, result_type_id: Word, id: Word, image: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction
+ }
+
+ //
+ // Conversion Instructions
+ //
+ pub(super) fn unary(op: Op, result_type_id: Word, id: Word, value: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ //
+ // Composite Instructions
+ //
+
+ pub(super) fn composite_construct(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::CompositeConstruct);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn composite_extract(
+ result_type_id: Word,
+ id: Word,
+ composite_id: Word,
+ indices: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::CompositeExtract);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ instruction.add_operand(composite_id);
+ for index in indices {
+ instruction.add_operand(*index);
+ }
+
+ instruction
+ }
+
+ pub(super) fn vector_extract_dynamic(
+ result_type_id: Word,
+ id: Word,
+ vector_id: Word,
+ index_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::VectorExtractDynamic);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ instruction.add_operand(vector_id);
+ instruction.add_operand(index_id);
+
+ instruction
+ }
+
+ pub(super) fn vector_shuffle(
+ result_type_id: Word,
+ id: Word,
+ v1_id: Word,
+ v2_id: Word,
+ components: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::VectorShuffle);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(v1_id);
+ instruction.add_operand(v2_id);
+
+ for &component in components {
+ instruction.add_operand(component);
+ }
+
+ instruction
+ }
+
+ //
+ // Arithmetic Instructions
+ //
+ pub(super) fn binary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction
+ }
+
+ pub(super) fn ternary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ operand_3: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction.add_operand(operand_3);
+ instruction
+ }
+
+ pub(super) fn quaternary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ operand_3: Word,
+ operand_4: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction.add_operand(operand_3);
+ instruction.add_operand(operand_4);
+ instruction
+ }
+
+ pub(super) fn relational(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(expr_id);
+ instruction
+ }
+
+ pub(super) fn atomic_binary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ pointer: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ value: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ //
+ // Bit Instructions
+ //
+
+ //
+ // Relational and Logical Instructions
+ //
+
+ //
+ // Derivative Instructions
+ //
+
+ pub(super) fn derivative(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(expr_id);
+ instruction
+ }
+
+ //
+ // Control-Flow Instructions
+ //
+
+ pub(super) fn phi(
+ result_type_id: Word,
+ result_id: Word,
+ var_parent_pairs: &[(Word, Word)],
+ ) -> Self {
+ let mut instruction = Self::new(Op::Phi);
+ instruction.add_operand(result_type_id);
+ instruction.add_operand(result_id);
+ for &(variable, parent) in var_parent_pairs {
+ instruction.add_operand(variable);
+ instruction.add_operand(parent);
+ }
+ instruction
+ }
+
+ pub(super) fn selection_merge(
+ merge_id: Word,
+ selection_control: spirv::SelectionControl,
+ ) -> Self {
+ let mut instruction = Self::new(Op::SelectionMerge);
+ instruction.add_operand(merge_id);
+ instruction.add_operand(selection_control.bits());
+ instruction
+ }
+
+ pub(super) fn loop_merge(
+ merge_id: Word,
+ continuing_id: Word,
+ selection_control: spirv::SelectionControl,
+ ) -> Self {
+ let mut instruction = Self::new(Op::LoopMerge);
+ instruction.add_operand(merge_id);
+ instruction.add_operand(continuing_id);
+ instruction.add_operand(selection_control.bits());
+ instruction
+ }
+
+ pub(super) fn label(id: Word) -> Self {
+ let mut instruction = Self::new(Op::Label);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn branch(id: Word) -> Self {
+ let mut instruction = Self::new(Op::Branch);
+ instruction.add_operand(id);
+ instruction
+ }
+
+ // TODO Branch Weights not implemented.
+ pub(super) fn branch_conditional(
+ condition_id: Word,
+ true_label: Word,
+ false_label: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::BranchConditional);
+ instruction.add_operand(condition_id);
+ instruction.add_operand(true_label);
+ instruction.add_operand(false_label);
+ instruction
+ }
+
+ pub(super) fn switch(selector_id: Word, default_id: Word, cases: &[Case]) -> Self {
+ let mut instruction = Self::new(Op::Switch);
+ instruction.add_operand(selector_id);
+ instruction.add_operand(default_id);
+ for case in cases {
+ instruction.add_operand(case.value);
+ instruction.add_operand(case.label_id);
+ }
+ instruction
+ }
+
+ pub(super) fn select(
+ result_type_id: Word,
+ id: Word,
+ condition_id: Word,
+ accept_id: Word,
+ reject_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Select);
+ instruction.add_operand(result_type_id);
+ instruction.add_operand(id);
+ instruction.add_operand(condition_id);
+ instruction.add_operand(accept_id);
+ instruction.add_operand(reject_id);
+ instruction
+ }
+
+ pub(super) const fn kill() -> Self {
+ Self::new(Op::Kill)
+ }
+
+ pub(super) const fn return_void() -> Self {
+ Self::new(Op::Return)
+ }
+
+ pub(super) fn return_value(value_id: Word) -> Self {
+ let mut instruction = Self::new(Op::ReturnValue);
+ instruction.add_operand(value_id);
+ instruction
+ }
+
+ //
+ // Atomic Instructions
+ //
+
+ //
+ // Primitive Instructions
+ //
+
+ // Barriers
+
+ pub(super) fn control_barrier(
+ exec_scope_id: Word,
+ mem_scope_id: Word,
+ semantics_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::ControlBarrier);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(mem_scope_id);
+ instruction.add_operand(semantics_id);
+ instruction
+ }
+}
+
+impl From<crate::StorageFormat> for spirv::ImageFormat {
+ fn from(format: crate::StorageFormat) -> Self {
+ use crate::StorageFormat as Sf;
+ match format {
+ Sf::R8Unorm => Self::R8,
+ Sf::R8Snorm => Self::R8Snorm,
+ Sf::R8Uint => Self::R8ui,
+ Sf::R8Sint => Self::R8i,
+ Sf::R16Uint => Self::R16ui,
+ Sf::R16Sint => Self::R16i,
+ Sf::R16Float => Self::R16f,
+ Sf::Rg8Unorm => Self::Rg8,
+ Sf::Rg8Snorm => Self::Rg8Snorm,
+ Sf::Rg8Uint => Self::Rg8ui,
+ Sf::Rg8Sint => Self::Rg8i,
+ Sf::R32Uint => Self::R32ui,
+ Sf::R32Sint => Self::R32i,
+ Sf::R32Float => Self::R32f,
+ Sf::Rg16Uint => Self::Rg16ui,
+ Sf::Rg16Sint => Self::Rg16i,
+ Sf::Rg16Float => Self::Rg16f,
+ Sf::Rgba8Unorm => Self::Rgba8,
+ Sf::Rgba8Snorm => Self::Rgba8Snorm,
+ Sf::Rgba8Uint => Self::Rgba8ui,
+ Sf::Rgba8Sint => Self::Rgba8i,
+ Sf::Rgb10a2Unorm => Self::Rgb10a2ui,
+ Sf::Rg11b10Float => Self::R11fG11fB10f,
+ Sf::Rg32Uint => Self::Rg32ui,
+ Sf::Rg32Sint => Self::Rg32i,
+ Sf::Rg32Float => Self::Rg32f,
+ Sf::Rgba16Uint => Self::Rgba16ui,
+ Sf::Rgba16Sint => Self::Rgba16i,
+ Sf::Rgba16Float => Self::Rgba16f,
+ Sf::Rgba32Uint => Self::Rgba32ui,
+ Sf::Rgba32Sint => Self::Rgba32i,
+ Sf::Rgba32Float => Self::Rgba32f,
+ }
+ }
+}
+
+impl From<crate::ImageDimension> for spirv::Dim {
+ fn from(dim: crate::ImageDimension) -> Self {
+ use crate::ImageDimension as Id;
+ match dim {
+ Id::D1 => Self::Dim1D,
+ Id::D2 => Self::Dim2D,
+ Id::D3 => Self::Dim3D,
+ Id::Cube => Self::DimCube,
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/layout.rs b/third_party/rust/naga/src/back/spv/layout.rs
new file mode 100644
index 0000000000..39117a3d2a
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/layout.rs
@@ -0,0 +1,210 @@
+use super::{Instruction, LogicalLayout, PhysicalLayout};
+use spirv::{Op, Word, MAGIC_NUMBER};
+use std::iter;
+
+// https://github.com/KhronosGroup/SPIRV-Headers/pull/195
+const GENERATOR: Word = 28;
+
+impl PhysicalLayout {
+ pub(super) const fn new(version: Word) -> Self {
+ PhysicalLayout {
+ magic_number: MAGIC_NUMBER,
+ version,
+ generator: GENERATOR,
+ bound: 0,
+ instruction_schema: 0x0u32,
+ }
+ }
+
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(iter::once(self.magic_number));
+ sink.extend(iter::once(self.version));
+ sink.extend(iter::once(self.generator));
+ sink.extend(iter::once(self.bound));
+ sink.extend(iter::once(self.instruction_schema));
+ }
+}
+
+impl super::recyclable::Recyclable for PhysicalLayout {
+ fn recycle(self) -> Self {
+ PhysicalLayout {
+ magic_number: self.magic_number,
+ version: self.version,
+ generator: self.generator,
+ instruction_schema: self.instruction_schema,
+ bound: 0,
+ }
+ }
+}
+
+impl LogicalLayout {
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(self.capabilities.iter().cloned());
+ sink.extend(self.extensions.iter().cloned());
+ sink.extend(self.ext_inst_imports.iter().cloned());
+ sink.extend(self.memory_model.iter().cloned());
+ sink.extend(self.entry_points.iter().cloned());
+ sink.extend(self.execution_modes.iter().cloned());
+ sink.extend(self.debugs.iter().cloned());
+ sink.extend(self.annotations.iter().cloned());
+ sink.extend(self.declarations.iter().cloned());
+ sink.extend(self.function_declarations.iter().cloned());
+ sink.extend(self.function_definitions.iter().cloned());
+ }
+}
+
+impl super::recyclable::Recyclable for LogicalLayout {
+ fn recycle(self) -> Self {
+ Self {
+ capabilities: self.capabilities.recycle(),
+ extensions: self.extensions.recycle(),
+ ext_inst_imports: self.ext_inst_imports.recycle(),
+ memory_model: self.memory_model.recycle(),
+ entry_points: self.entry_points.recycle(),
+ execution_modes: self.execution_modes.recycle(),
+ debugs: self.debugs.recycle(),
+ annotations: self.annotations.recycle(),
+ declarations: self.declarations.recycle(),
+ function_declarations: self.function_declarations.recycle(),
+ function_definitions: self.function_definitions.recycle(),
+ }
+ }
+}
+
+impl Instruction {
+ pub(super) const fn new(op: Op) -> Self {
+ Instruction {
+ op,
+ wc: 1, // Always start at 1 for the first word (OP + WC),
+ type_id: None,
+ result_id: None,
+ operands: vec![],
+ }
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_type(&mut self, id: Word) {
+ assert!(self.type_id.is_none(), "Type can only be set once");
+ self.type_id = Some(id);
+ self.wc += 1;
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_result(&mut self, id: Word) {
+ assert!(self.result_id.is_none(), "Result can only be set once");
+ self.result_id = Some(id);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operand(&mut self, operand: Word) {
+ self.operands.push(operand);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operands(&mut self, operands: Vec<Word>) {
+ for operand in operands.into_iter() {
+ self.add_operand(operand)
+ }
+ }
+
+ pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(Some(self.wc << 16 | self.op as u32));
+ sink.extend(self.type_id);
+ sink.extend(self.result_id);
+ sink.extend(self.operands.iter().cloned());
+ }
+}
+
+impl Instruction {
+ #[cfg(test)]
+ fn validate(&self, words: &[Word]) {
+ let mut inst_index = 0;
+ let (wc, op) = ((words[inst_index] >> 16) as u16, words[inst_index] as u16);
+ inst_index += 1;
+
+ assert_eq!(wc, words.len() as u16);
+ assert_eq!(op, self.op as u16);
+
+ if self.type_id.is_some() {
+ assert_eq!(words[inst_index], self.type_id.unwrap());
+ inst_index += 1;
+ }
+
+ if self.result_id.is_some() {
+ assert_eq!(words[inst_index], self.result_id.unwrap());
+ inst_index += 1;
+ }
+
+ for (op_index, i) in (inst_index..wc as usize).enumerate() {
+ assert_eq!(words[i], self.operands[op_index]);
+ }
+ }
+}
+
+#[test]
+fn test_physical_layout_in_words() {
+ let bound = 5;
+ let version = 0x10203;
+
+ let mut output = vec![];
+ let mut layout = PhysicalLayout::new(version);
+ layout.bound = bound;
+
+ layout.in_words(&mut output);
+
+ assert_eq!(&output, &[MAGIC_NUMBER, version, GENERATOR, bound, 0,]);
+}
+
+#[test]
+fn test_logical_layout_in_words() {
+ let mut output = vec![];
+ let mut layout = LogicalLayout::default();
+ let layout_vectors = 11;
+ let mut instructions = Vec::with_capacity(layout_vectors);
+
+ let vector_names = &[
+ "Capabilities",
+ "Extensions",
+ "External Instruction Imports",
+ "Memory Model",
+ "Entry Points",
+ "Execution Modes",
+ "Debugs",
+ "Annotations",
+ "Declarations",
+ "Function Declarations",
+ "Function Definitions",
+ ];
+
+ for (i, _) in vector_names.iter().enumerate().take(layout_vectors) {
+ let mut dummy_instruction = Instruction::new(Op::Constant);
+ dummy_instruction.set_type((i + 1) as u32);
+ dummy_instruction.set_result((i + 2) as u32);
+ dummy_instruction.add_operand((i + 3) as u32);
+ dummy_instruction.add_operands(super::helpers::string_to_words(
+ format!("This is the vector: {}", vector_names[i]).as_str(),
+ ));
+ instructions.push(dummy_instruction);
+ }
+
+ instructions[0].to_words(&mut layout.capabilities);
+ instructions[1].to_words(&mut layout.extensions);
+ instructions[2].to_words(&mut layout.ext_inst_imports);
+ instructions[3].to_words(&mut layout.memory_model);
+ instructions[4].to_words(&mut layout.entry_points);
+ instructions[5].to_words(&mut layout.execution_modes);
+ instructions[6].to_words(&mut layout.debugs);
+ instructions[7].to_words(&mut layout.annotations);
+ instructions[8].to_words(&mut layout.declarations);
+ instructions[9].to_words(&mut layout.function_declarations);
+ instructions[10].to_words(&mut layout.function_definitions);
+
+ layout.in_words(&mut output);
+
+ let mut index: usize = 0;
+ for instruction in instructions {
+ let wc = instruction.wc as usize;
+ instruction.validate(&output[index..index + wc]);
+ index += wc;
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs
new file mode 100644
index 0000000000..544f5ca4f5
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/mod.rs
@@ -0,0 +1,696 @@
+/*!
+Backend for [SPIR-V][spv] (Standard Portable Intermediate Representation).
+
+[spv]: https://www.khronos.org/registry/SPIR-V/
+*/
+
+mod block;
+mod helpers;
+mod image;
+mod index;
+mod instructions;
+mod layout;
+mod recyclable;
+mod selection;
+mod writer;
+
+pub use spirv::Capability;
+
+use crate::arena::Handle;
+use crate::proc::{BoundsCheckPolicies, TypeResolution};
+
+use spirv::Word;
+use std::ops;
+use thiserror::Error;
+
+#[derive(Clone)]
+struct PhysicalLayout {
+ magic_number: Word,
+ version: Word,
+ generator: Word,
+ bound: Word,
+ instruction_schema: Word,
+}
+
+#[derive(Default)]
+struct LogicalLayout {
+ capabilities: Vec<Word>,
+ extensions: Vec<Word>,
+ ext_inst_imports: Vec<Word>,
+ memory_model: Vec<Word>,
+ entry_points: Vec<Word>,
+ execution_modes: Vec<Word>,
+ debugs: Vec<Word>,
+ annotations: Vec<Word>,
+ declarations: Vec<Word>,
+ function_declarations: Vec<Word>,
+ function_definitions: Vec<Word>,
+}
+
+struct Instruction {
+ op: spirv::Op,
+ wc: u32,
+ type_id: Option<Word>,
+ result_id: Option<Word>,
+ operands: Vec<Word>,
+}
+
+const BITS_PER_BYTE: crate::Bytes = 8;
+
+#[derive(Clone, Debug, Error)]
+pub enum Error {
+ #[error("The requested entry point couldn't be found")]
+ EntryPointNotFound,
+ #[error("target SPIRV-{0}.{1} is not supported")]
+ UnsupportedVersion(u8, u8),
+ #[error("using {0} requires at least one of the capabilities {1:?}, but none are available")]
+ MissingCapabilities(&'static str, Vec<Capability>),
+ #[error("unimplemented {0}")]
+ FeatureNotImplemented(&'static str),
+ #[error("module is not validated properly: {0}")]
+ Validation(&'static str),
+}
+
+#[derive(Default)]
+struct IdGenerator(Word);
+
+impl IdGenerator {
+ fn next(&mut self) -> Word {
+ self.0 += 1;
+ self.0
+ }
+}
+
+/// A SPIR-V block to which we are still adding instructions.
+///
+/// A `Block` represents a SPIR-V block that does not yet have a termination
+/// instruction like `OpBranch` or `OpReturn`.
+///
+/// The `OpLabel` that starts the block is implicit. It will be emitted based on
+/// `label_id` when we write the block to a `LogicalLayout`.
+///
+/// To terminate a `Block`, pass the block and the termination instruction to
+/// `Function::consume`. This takes ownership of the `Block` and transforms it
+/// into a `TerminatedBlock`.
+struct Block {
+ label_id: Word,
+ body: Vec<Instruction>,
+}
+
+/// A SPIR-V block that ends with a termination instruction.
+struct TerminatedBlock {
+ label_id: Word,
+ body: Vec<Instruction>,
+}
+
+impl Block {
+ const fn new(label_id: Word) -> Self {
+ Block {
+ label_id,
+ body: Vec::new(),
+ }
+ }
+}
+
+struct LocalVariable {
+ id: Word,
+ instruction: Instruction,
+}
+
+struct ResultMember {
+ id: Word,
+ type_id: Word,
+ built_in: Option<crate::BuiltIn>,
+}
+
+struct EntryPointContext {
+ argument_ids: Vec<Word>,
+ results: Vec<ResultMember>,
+}
+
+#[derive(Default)]
+struct Function {
+ signature: Option<Instruction>,
+ parameters: Vec<FunctionArgument>,
+ variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
+ blocks: Vec<TerminatedBlock>,
+ entry_point_context: Option<EntryPointContext>,
+}
+
+impl Function {
+ fn consume(&mut self, mut block: Block, termination: Instruction) {
+ block.body.push(termination);
+ self.blocks.push(TerminatedBlock {
+ label_id: block.label_id,
+ body: block.body,
+ })
+ }
+
+ fn parameter_id(&self, index: u32) -> Word {
+ match self.entry_point_context {
+ Some(ref context) => context.argument_ids[index as usize],
+ None => self.parameters[index as usize]
+ .instruction
+ .result_id
+ .unwrap(),
+ }
+ }
+}
+
+/// Characteristics of a SPIR-V `OpTypeImage` type.
+///
+/// SPIR-V requires non-composite types to be unique, including images. Since we
+/// use `LocalType` for this deduplication, it's essential that `LocalImageType`
+/// be equal whenever the corresponding `OpTypeImage`s would be. To reduce the
+/// likelihood of mistakes, we use fields that correspond exactly to the
+/// operands of an `OpTypeImage` instruction, using the actual SPIR-V types
+/// where practical.
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+struct LocalImageType {
+ sampled_type: crate::ScalarKind,
+ dim: spirv::Dim,
+ flags: ImageTypeFlags,
+ image_format: spirv::ImageFormat,
+}
+
+bitflags::bitflags! {
+ /// Flags corresponding to the boolean(-ish) parameters to OpTypeImage.
+ pub struct ImageTypeFlags: u8 {
+ const DEPTH = 0x1;
+ const ARRAYED = 0x2;
+ const MULTISAMPLED = 0x4;
+ const SAMPLED = 0x8;
+ }
+}
+
+impl LocalImageType {
+ /// Construct a `LocalImageType` from the fields of a `TypeInner::Image`.
+ fn from_inner(dim: crate::ImageDimension, arrayed: bool, class: crate::ImageClass) -> Self {
+ let make_flags = |multi: bool, other: ImageTypeFlags| -> ImageTypeFlags {
+ let mut flags = other;
+ flags.set(ImageTypeFlags::ARRAYED, arrayed);
+ flags.set(ImageTypeFlags::MULTISAMPLED, multi);
+ flags
+ };
+
+ let dim = spirv::Dim::from(dim);
+
+ match class {
+ crate::ImageClass::Sampled { kind, multi } => LocalImageType {
+ sampled_type: kind,
+ dim,
+ flags: make_flags(multi, ImageTypeFlags::SAMPLED),
+ image_format: spirv::ImageFormat::Unknown,
+ },
+ crate::ImageClass::Depth { multi } => LocalImageType {
+ sampled_type: crate::ScalarKind::Float,
+ dim,
+ flags: make_flags(multi, ImageTypeFlags::DEPTH | ImageTypeFlags::SAMPLED),
+ image_format: spirv::ImageFormat::Unknown,
+ },
+ crate::ImageClass::Storage { format, access: _ } => LocalImageType {
+ sampled_type: crate::ScalarKind::from(format),
+ dim,
+ flags: make_flags(false, ImageTypeFlags::empty()),
+ image_format: format.into(),
+ },
+ }
+ }
+}
+
+/// A SPIR-V type constructed during code generation.
+///
+/// This is the variant of [`LookupType`] used to represent types that might not
+/// be available in the arena. Variants are present here for one of two reasons:
+///
+/// - They represent types synthesized during code generation, as explained
+/// in the documentation for [`LookupType`].
+///
+/// - They represent types for which SPIR-V forbids duplicate `OpType...`
+/// instructions, requiring deduplication.
+///
+/// This is not a complete copy of [`TypeInner`]: for example, SPIR-V generation
+/// never synthesizes new struct types, so `LocalType` has nothing for that.
+///
+/// Each `LocalType` variant should be handled identically to its analogous
+/// `TypeInner` variant. You can use the [`make_local`] function to help with
+/// this, by converting everything possible to a `LocalType` before inspecting
+/// it.
+///
+/// ## `Localtype` equality and SPIR-V `OpType` uniqueness
+///
+/// The definition of `Eq` on `LocalType` is carefully chosen to help us follow
+/// certain SPIR-V rules. SPIR-V §2.8 requires some classes of `OpType...`
+/// instructions to be unique; for example, you can't have two `OpTypeInt 32 1`
+/// instructions in the same module. All 32-bit signed integers must use the
+/// same type id.
+///
+/// All SPIR-V types that must be unique can be represented as a `LocalType`,
+/// and two `LocalType`s are always `Eq` if SPIR-V would require them to use the
+/// same `OpType...` instruction. This lets us avoid duplicates by recording the
+/// ids of the type instructions we've already generated in a hash table,
+/// [`Writer::lookup_type`], keyed by `LocalType`.
+///
+/// As another example, [`LocalImageType`], stored in the `LocalType::Image`
+/// variant, is designed to help us deduplicate `OpTypeImage` instructions. See
+/// its documentation for details.
+///
+/// `LocalType` also includes variants like `Pointer` that do not need to be
+/// unique - but it is harmless to avoid the duplication.
+///
+/// As it always must, the `Hash` implementation respects the `Eq` relation.
+///
+/// [`TypeInner`]: crate::TypeInner
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LocalType {
+ /// A scalar, vector, or pointer to one of those.
+ Value {
+ /// If `None`, this represents a scalar type. If `Some`, this represents
+ /// a vector type of the given size.
+ vector_size: Option<crate::VectorSize>,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ pointer_space: Option<spirv::StorageClass>,
+ },
+ /// A matrix of floating-point values.
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+ Pointer {
+ base: Handle<crate::Type>,
+ class: spirv::StorageClass,
+ },
+ Image(LocalImageType),
+ SampledImage {
+ image_type_id: Word,
+ },
+ Sampler,
+ PointerToBindingArray {
+ base: Handle<crate::Type>,
+ size: u64,
+ },
+ BindingArray {
+ base: Handle<crate::Type>,
+ size: u64,
+ },
+}
+
+/// A type encountered during SPIR-V generation.
+///
+/// In the process of writing SPIR-V, we need to synthesize various types for
+/// intermediate results and such: pointer types, vector/matrix component types,
+/// or even booleans, which usually appear in SPIR-V code even when they're not
+/// used by the module source.
+///
+/// However, we can't use `crate::Type` or `crate::TypeInner` for these, as the
+/// type arena may not contain what we need (it only contains types used
+/// directly by other parts of the IR), and the IR module is immutable, so we
+/// can't add anything to it.
+///
+/// So for local use in the SPIR-V writer, we use this type, which holds either
+/// a handle into the arena, or a [`LocalType`] containing something synthesized
+/// locally.
+///
+/// This is very similar to the [`proc::TypeResolution`] enum, with `LocalType`
+/// playing the role of `TypeInner`. However, `LocalType` also has other
+/// properties needed for SPIR-V generation; see the description of
+/// [`LocalType`] for details.
+///
+/// [`proc::TypeResolution`]: crate::proc::TypeResolution
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LookupType {
+ Handle(Handle<crate::Type>),
+ Local(LocalType),
+}
+
+impl From<LocalType> for LookupType {
+ fn from(local: LocalType) -> Self {
+ Self::Local(local)
+ }
+}
+
+#[derive(Debug, PartialEq, Clone, Hash, Eq)]
+struct LookupFunctionType {
+ parameter_type_ids: Vec<Word>,
+ return_type_id: Word,
+}
+
+fn make_local(inner: &crate::TypeInner) -> Option<LocalType> {
+ Some(match *inner {
+ crate::TypeInner::Scalar { kind, width } | crate::TypeInner::Atomic { kind, width } => {
+ LocalType::Value {
+ vector_size: None,
+ kind,
+ width,
+ pointer_space: None,
+ }
+ }
+ crate::TypeInner::Vector { size, kind, width } => LocalType::Value {
+ vector_size: Some(size),
+ kind,
+ width,
+ pointer_space: None,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => LocalType::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ crate::TypeInner::Pointer { base, space } => LocalType::Pointer {
+ base,
+ class: helpers::map_storage_class(space),
+ },
+ crate::TypeInner::ValuePointer {
+ size,
+ kind,
+ width,
+ space,
+ } => LocalType::Value {
+ vector_size: size,
+ kind,
+ width,
+ pointer_space: Some(helpers::map_storage_class(space)),
+ },
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
+ crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
+ _ => return None,
+ })
+}
+
+#[derive(Debug)]
+enum Dimension {
+ Scalar,
+ Vector,
+ Matrix,
+}
+
+/// A map from evaluated [`Expression`](crate::Expression)s to their SPIR-V ids.
+///
+/// When we emit code to evaluate a given `Expression`, we record the
+/// SPIR-V id of its value here, under its `Handle<Expression>` index.
+///
+/// A `CachedExpressions` value can be indexed by a `Handle<Expression>` value.
+///
+/// [emit]: index.html#expression-evaluation-time-and-scope
+#[derive(Default)]
+struct CachedExpressions {
+ ids: Vec<Word>,
+}
+impl CachedExpressions {
+ fn reset(&mut self, length: usize) {
+ self.ids.clear();
+ self.ids.resize(length, 0);
+ }
+}
+impl ops::Index<Handle<crate::Expression>> for CachedExpressions {
+ type Output = Word;
+ fn index(&self, h: Handle<crate::Expression>) -> &Word {
+ let id = &self.ids[h.index()];
+ if *id == 0 {
+ unreachable!("Expression {:?} is not cached!", h);
+ }
+ id
+ }
+}
+impl ops::IndexMut<Handle<crate::Expression>> for CachedExpressions {
+ fn index_mut(&mut self, h: Handle<crate::Expression>) -> &mut Word {
+ let id = &mut self.ids[h.index()];
+ if *id != 0 {
+ unreachable!("Expression {:?} is already cached!", h);
+ }
+ id
+ }
+}
+impl recyclable::Recyclable for CachedExpressions {
+ fn recycle(self) -> Self {
+ CachedExpressions {
+ ids: self.ids.recycle(),
+ }
+ }
+}
+
+#[derive(Clone)]
+struct GlobalVariable {
+ /// ID of the OpVariable that declares the global.
+ ///
+ /// If you need the variable's value, use [`access_id`] instead of this
+ /// field. If we wrapped the Naga IR `GlobalVariable`'s type in a struct to
+ /// comply with Vulkan's requirements, then this points to the `OpVariable`
+ /// with the synthesized struct type, whereas `access_id` points to the
+ /// field of said struct that holds the variable's actual value.
+ ///
+ /// This is used to compute the `access_id` pointer in function prologues,
+ /// and used for `ArrayLength` expressions, which do need the struct.
+ ///
+ /// [`access_id`]: GlobalVariable::access_id
+ var_id: Word,
+
+ /// For `AddressSpace::Handle` variables, this ID is recorded in the function
+ /// prelude block (and reset before every function) as `OpLoad` of the variable.
+ /// It is then used for all the global ops, such as `OpImageSample`.
+ handle_id: Word,
+
+ /// Actual ID used to access this variable.
+ /// For wrapped buffer variables, this ID is `OpAccessChain` into the
+ /// wrapper. Otherwise, the same as `var_id`.
+ ///
+ /// Vulkan requires that globals in the `StorageBuffer` and `Uniform` storage
+ /// classes must be structs with the `Block` decoration, but WGSL and Naga IR
+ /// make no such requirement. So for such variables, we generate a wrapper struct
+ /// type with a single element of the type given by Naga, generate an
+ /// `OpAccessChain` for that member in the function prelude, and use that pointer
+ /// to refer to the global in the function body. This is the id of that access,
+ /// updated for each function in `write_function`.
+ access_id: Word,
+}
+
+impl GlobalVariable {
+ const fn dummy() -> Self {
+ Self {
+ var_id: 0,
+ handle_id: 0,
+ access_id: 0,
+ }
+ }
+
+ const fn new(id: Word) -> Self {
+ Self {
+ var_id: id,
+ handle_id: 0,
+ access_id: 0,
+ }
+ }
+
+ /// Prepare `self` for use within a single function.
+ fn reset_for_function(&mut self) {
+ self.handle_id = 0;
+ self.access_id = 0;
+ }
+}
+
+struct FunctionArgument {
+ /// Actual instruction of the argument.
+ instruction: Instruction,
+ handle_id: Word,
+}
+
+/// General information needed to emit SPIR-V for Naga statements.
+struct BlockContext<'w> {
+ /// The writer handling the module to which this code belongs.
+ writer: &'w mut Writer,
+
+ /// The [`Module`](crate::Module) for which we're generating code.
+ ir_module: &'w crate::Module,
+
+ /// The [`Function`](crate::Function) for which we're generating code.
+ ir_function: &'w crate::Function,
+
+ /// Information module validation produced about
+ /// [`ir_function`](BlockContext::ir_function).
+ fun_info: &'w crate::valid::FunctionInfo,
+
+ /// The [`spv::Function`](Function) to which we are contributing SPIR-V instructions.
+ function: &'w mut Function,
+
+ /// SPIR-V ids for expressions we've evaluated.
+ cached: CachedExpressions,
+
+ /// The `Writer`'s temporary vector, for convenience.
+ temp_list: Vec<Word>,
+}
+
+impl BlockContext<'_> {
+ fn gen_id(&mut self) -> Word {
+ self.writer.id_gen.next()
+ }
+
+ fn get_type_id(&mut self, lookup_type: LookupType) -> Word {
+ self.writer.get_type_id(lookup_type)
+ }
+
+ fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
+ self.writer.get_expression_type_id(tr)
+ }
+
+ fn get_index_constant(&mut self, index: Word) -> Word {
+ self.writer
+ .get_constant_scalar(crate::ScalarValue::Uint(index as _), 4)
+ }
+
+ fn get_scope_constant(&mut self, scope: Word) -> Word {
+ self.writer
+ .get_constant_scalar(crate::ScalarValue::Sint(scope as _), 4)
+ }
+}
+
+#[derive(Clone, Copy, Default)]
+struct LoopContext {
+ continuing_id: Option<Word>,
+ break_id: Option<Word>,
+}
+
+pub struct Writer {
+ physical_layout: PhysicalLayout,
+ logical_layout: LogicalLayout,
+ id_gen: IdGenerator,
+
+ /// The set of capabilities modules are permitted to use.
+ ///
+ /// This is initialized from `Options::capabilities`.
+ capabilities_available: Option<crate::FastHashSet<Capability>>,
+
+ /// The set of capabilities used by this module.
+ ///
+ /// If `capabilities_available` is `Some`, then this is always a subset of
+ /// that.
+ capabilities_used: crate::FastHashSet<Capability>,
+
+ /// The set of spirv extensions used.
+ extensions_used: crate::FastHashSet<&'static str>,
+
+ debugs: Vec<Instruction>,
+ annotations: Vec<Instruction>,
+ flags: WriterFlags,
+ bounds_check_policies: BoundsCheckPolicies,
+ void_type: Word,
+ //TODO: convert most of these into vectors, addressable by handle indices
+ lookup_type: crate::FastHashMap<LookupType, Word>,
+ lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>,
+ lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>,
+ constant_ids: Vec<Word>,
+ cached_constants: crate::FastHashMap<(crate::ScalarValue, crate::Bytes), Word>,
+ global_variables: Vec<GlobalVariable>,
+ binding_map: BindingMap,
+
+ // Cached expressions are only meaningful within a BlockContext, but we
+ // retain the table here between functions to save heap allocations.
+ saved_cached: CachedExpressions,
+
+ gl450_ext_inst_id: Word,
+ // Just a temporary list of SPIR-V ids
+ temp_list: Vec<Word>,
+}
+
+bitflags::bitflags! {
+ pub struct WriterFlags: u32 {
+ /// Include debug labels for everything.
+ const DEBUG = 0x1;
+ /// Flip Y coordinate of `BuiltIn::Position` output.
+ const ADJUST_COORDINATE_SPACE = 0x2;
+ /// Emit `OpName` for input/output locations.
+ /// Contrary to spec, some drivers treat it as semantic, not allowing
+ /// any conflicts.
+ const LABEL_VARYINGS = 0x4;
+ /// Emit `PointSize` output builtin to vertex shaders, which is
+ /// required for drawing with `PointList` topology.
+ const FORCE_POINT_SIZE = 0x8;
+ /// Clamp `BuiltIn::FragDepth` output between 0 and 1.
+ const CLAMP_FRAG_DEPTH = 0x10;
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BindingInfo {
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindingInfo>;
+
+#[derive(Debug, Clone)]
+pub struct Options {
+ /// (Major, Minor) target version of the SPIR-V.
+ pub lang_version: (u8, u8),
+
+ /// Configuration flags for the writer.
+ pub flags: WriterFlags,
+
+ /// Map of resources to information about the binding.
+ pub binding_map: BindingMap,
+
+ /// If given, the set of capabilities modules are allowed to use. Code that
+ /// requires capabilities beyond these is rejected with an error.
+ ///
+ /// If this is `None`, all capabilities are permitted.
+ pub capabilities: Option<crate::FastHashSet<Capability>>,
+
+ /// How should generate code handle array, vector, matrix, or image texel
+ /// indices that are out of range?
+ pub bounds_check_policies: BoundsCheckPolicies,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ let mut flags = WriterFlags::ADJUST_COORDINATE_SPACE
+ | WriterFlags::LABEL_VARYINGS
+ | WriterFlags::CLAMP_FRAG_DEPTH;
+ if cfg!(debug_assertions) {
+ flags |= WriterFlags::DEBUG;
+ }
+ Options {
+ lang_version: (1, 0),
+ flags,
+ binding_map: BindingMap::default(),
+ capabilities: None,
+ bounds_check_policies: crate::proc::BoundsCheckPolicies::default(),
+ }
+ }
+}
+
+// A subset of options meant to be changed per pipeline.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// The stage of the entry point.
+ pub shader_stage: crate::ShaderStage,
+ /// The name of the entry point.
+ ///
+ /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
+ pub entry_point: String,
+}
+
+pub fn write_vec(
+ module: &crate::Module,
+ info: &crate::valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: Option<&PipelineOptions>,
+) -> Result<Vec<u32>, Error> {
+ let mut words = Vec::new();
+ let mut w = Writer::new(options)?;
+ w.write(module, info, pipeline_options, &mut words)?;
+ Ok(words)
+}
diff --git a/third_party/rust/naga/src/back/spv/recyclable.rs b/third_party/rust/naga/src/back/spv/recyclable.rs
new file mode 100644
index 0000000000..49f3a02741
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/recyclable.rs
@@ -0,0 +1,60 @@
+/*!
+Reusing collections' previous allocations.
+*/
+
+/// A value that can be reset to its initial state, retaining its current allocations.
+///
+/// Naga attempts to lower the cost of SPIR-V generation by allowing clients to
+/// reuse the same `Writer` for multiple Module translations. Reusing a `Writer`
+/// means that the `Vec`s, `HashMap`s, and other heap-allocated structures the
+/// `Writer` uses internally begin the translation with heap-allocated buffers
+/// ready to use.
+///
+/// But this approach introduces the risk of `Writer` state leaking from one
+/// module to the next. When a developer adds fields to `Writer` or its internal
+/// types, they must remember to reset their contents between modules.
+///
+/// One trick to ensure that every field has been accounted for is to use Rust's
+/// struct literal syntax to construct a new, reset value. If a developer adds a
+/// field, but neglects to update the reset code, the compiler will complain
+/// that a field is missing from the literal. This trait's `recycle` method
+/// takes `self` by value, and returns `Self` by value, encouraging the use of
+/// struct literal expressions in its implementation.
+pub trait Recyclable {
+ /// Clear `self`, retaining its current memory allocations.
+ ///
+ /// Shrink the buffer if it's currently much larger than was actually used.
+ /// This prevents a module with exceptionally large allocations from causing
+ /// the `Writer` to retain more memory than it needs indefinitely.
+ fn recycle(self) -> Self;
+}
+
+// Stock values for various collections.
+
+impl<T> Recyclable for Vec<T> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, V, S: Clone> Recyclable for std::collections::HashMap<K, V, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, S: Clone> Recyclable for std::collections::HashSet<K, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K: Ord, V> Recyclable for std::collections::BTreeMap<K, V> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/selection.rs b/third_party/rust/naga/src/back/spv/selection.rs
new file mode 100644
index 0000000000..788b1f10ab
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/selection.rs
@@ -0,0 +1,257 @@
+/*!
+Generate SPIR-V conditional structures.
+
+Builders for `if` structures with `and`s.
+
+The types in this module track the information needed to emit SPIR-V code
+for complex conditional structures, like those whose conditions involve
+short-circuiting 'and' and 'or' structures. These track labels and can emit
+`OpPhi` instructions to merge values produced along different paths.
+
+This currently only supports exactly the forms Naga uses, so it doesn't
+support `or` or `else`, and only supports zero or one merged values.
+
+Naga needs to emit code roughly like this:
+
+```ignore
+
+ value = DEFAULT;
+ if COND1 && COND2 {
+ value = THEN_VALUE;
+ }
+ // use value
+
+```
+
+Assuming `ctx` and `block` are a mutable references to a [`BlockContext`]
+and the current [`Block`], and `merge_type` is the SPIR-V type for the
+merged value `value`, we can build SPIR-V for the code above like so:
+
+```ignore
+
+ let cond = Selection::start(block, merge_type);
+ // ... compute `cond1` ...
+ cond.if_true(ctx, cond1, DEFAULT);
+ // ... compute `cond2` ...
+ cond.if_true(ctx, cond2, DEFAULT);
+ // ... compute THEN_VALUE
+ let merged_value = cond.finish(ctx, THEN_VALUE);
+
+```
+
+After this, `merged_value` is either `DEFAULT` or `THEN_VALUE`, depending on
+the path by which the merged block was reached.
+
+This takes care of writing all branch instructions, including an
+`OpSelectionMerge` annotation in the header block; starting new blocks and
+assigning them labels; and emitting the `OpPhi` that gathers together the
+right sources for the merged values, for every path through the selection
+construct.
+
+When there is no merged value to produce, you can pass `()` for `merge_type`
+and the merge values. In this case no `OpPhi` instructions are produced, and
+the `finish` method returns `()`.
+
+To enforce proper nesting, a `Selection` takes ownership of the `&mut Block`
+pointer for the duration of its lifetime. To obtain the block for generating
+code in the selection's body, call the `Selection::block` method.
+*/
+
+use super::{Block, BlockContext, Instruction};
+use spirv::Word;
+
+/// A private struct recording what we know about the selection construct so far.
+pub(super) struct Selection<'b, M: MergeTuple> {
+ /// The block pointer we're emitting code into.
+ block: &'b mut Block,
+
+ /// The label of the selection construct's merge block, or `None` if we
+ /// haven't yet written the `OpSelectionMerge` merge instruction.
+ merge_label: Option<Word>,
+
+ /// A set of `(VALUES, PARENT)` pairs, used to build `OpPhi` instructions in
+ /// the merge block. Each `PARENT` is the label of a predecessor block of
+ /// the merge block. The corresponding `VALUES` holds the ids of the values
+ /// that `PARENT` contributes to the merged values.
+ ///
+ /// We emit all branches to the merge block, so we know all its
+ /// predecessors. And we refuse to emit a branch unless we're given the
+ /// values the branching block contributes to the merge, so we always have
+ /// everything we need to emit the correct phis, by construction.
+ values: Vec<(M, Word)>,
+
+ /// The types of the values in each element of `values`.
+ merge_types: M,
+}
+
+impl<'b, M: MergeTuple> Selection<'b, M> {
+ /// Start a new selection construct.
+ ///
+ /// The `block` argument indicates the selection's header block.
+ ///
+ /// The `merge_types` argument should be a `Word` or tuple of `Word`s, each
+ /// value being the SPIR-V result type id of an `OpPhi` instruction that
+ /// will be written to the selection's merge block when this selection's
+ /// [`finish`] method is called. This argument may also be `()`, for
+ /// selections that produce no values.
+ ///
+ /// (This function writes no code to `block` itself; it simply constructs a
+ /// fresh `Selection`.)
+ ///
+ /// [`finish`]: Selection::finish
+ pub(super) fn start(block: &'b mut Block, merge_types: M) -> Self {
+ Selection {
+ block,
+ merge_label: None,
+ values: vec![],
+ merge_types,
+ }
+ }
+
+ pub(super) fn block(&mut self) -> &mut Block {
+ self.block
+ }
+
+ /// Branch to a successor block if `cond` is true, otherwise merge.
+ ///
+ /// If `cond` is false, branch to the merge block, using `values` as the
+ /// merged values. Otherwise, proceed to a new block.
+ ///
+ /// The `values` argument must be the same shape as the `merge_types`
+ /// argument passed to `Selection::start`.
+ pub(super) fn if_true(&mut self, ctx: &mut BlockContext, cond: Word, values: M) {
+ self.values.push((values, self.block.label_id));
+
+ let merge_label = self.make_merge_label(ctx);
+ let next_label = ctx.gen_id();
+ ctx.function.consume(
+ std::mem::replace(self.block, Block::new(next_label)),
+ Instruction::branch_conditional(cond, next_label, merge_label),
+ );
+ }
+
+ /// Emit an unconditional branch to the merge block, and compute merged
+ /// values.
+ ///
+ /// Use `final_values` as the merged values contributed by the current
+ /// block, and transition to the merge block, emitting `OpPhi` instructions
+ /// to produce the merged values. This must be the same shape as the
+ /// `merge_types` argument passed to [`Selection::start`].
+ ///
+ /// Return the SPIR-V ids of the merged values. This value has the same
+ /// shape as the `merge_types` argument passed to `Selection::start`.
+ pub(super) fn finish(self, ctx: &mut BlockContext, final_values: M) -> M {
+ match self {
+ Selection {
+ merge_label: None, ..
+ } => {
+ // We didn't actually emit any branches, so `self.values` must
+ // be empty, and `final_values` are the only sources we have for
+ // the merged values. Easy peasy.
+ final_values
+ }
+
+ Selection {
+ block,
+ merge_label: Some(merge_label),
+ mut values,
+ merge_types,
+ } => {
+ // Emit the final branch and transition to the merge block.
+ values.push((final_values, block.label_id));
+ ctx.function.consume(
+ std::mem::replace(block, Block::new(merge_label)),
+ Instruction::branch(merge_label),
+ );
+
+ // Now that we're in the merge block, build the phi instructions.
+ merge_types.write_phis(ctx, block, &values)
+ }
+ }
+ }
+
+ /// Return the id of the merge block, writing a merge instruction if needed.
+ fn make_merge_label(&mut self, ctx: &mut BlockContext) -> Word {
+ match self.merge_label {
+ None => {
+ let merge_label = ctx.gen_id();
+ self.block.body.push(Instruction::selection_merge(
+ merge_label,
+ spirv::SelectionControl::NONE,
+ ));
+ self.merge_label = Some(merge_label);
+ merge_label
+ }
+ Some(merge_label) => merge_label,
+ }
+ }
+}
+
+/// A trait to help `Selection` manage any number of merged values.
+///
+/// Some selection constructs, like a `ReadZeroSkipWrite` bounds check on a
+/// [`Load`] expression, produce a single merged value. Others produce no merged
+/// value, like a bounds check on a [`Store`] statement.
+///
+/// To let `Selection` work nicely with both cases, we let the merge type
+/// argument passed to [`Selection::start`] be any type that implements this
+/// `MergeTuple` trait. `MergeTuple` is then implemented for `()`, `Word`,
+/// `(Word, Word)`, and so on.
+///
+/// A `MergeTuple` type can represent either a bunch of SPIR-V types or values;
+/// the `merge_types` argument to `Selection::start` are type ids, whereas the
+/// `values` arguments to the [`if_true`] and [`finish`] methods are value ids.
+/// The set of merged value returned by `finish` is a tuple of value ids.
+///
+/// In fact, since Naga only uses zero- and single-valued selection constructs
+/// at present, we only implement `MergeTuple` for `()` and `Word`. But if you
+/// add more cases, feel free to add more implementations. Once const generics
+/// are available, we could have a single implementation of `MergeTuple` for all
+/// lengths of arrays, and be done with it.
+///
+/// [`Load`]: crate::Expression::Load
+/// [`Store`]: crate::Statement::Store
+/// [`if_true`]: Selection::if_true
+/// [`finish`]: Selection::finish
+pub(super) trait MergeTuple: Sized {
+ /// Write OpPhi instructions for the given set of predecessors.
+ ///
+ /// The `predecessors` vector should be a vector of `(LABEL, VALUES)` pairs,
+ /// where each `VALUES` holds the values contributed by the branch from
+ /// `LABEL`, which should be one of the current block's predecessors.
+ fn write_phis(
+ self,
+ ctx: &mut BlockContext,
+ block: &mut Block,
+ predecessors: &[(Self, Word)],
+ ) -> Self;
+}
+
+/// Selections that produce a single merged value.
+///
+/// For example, `ImageLoad` with `BoundsCheckPolicy::ReadZeroSkipWrite` either
+/// returns a texel value or zeros.
+impl MergeTuple for Word {
+ fn write_phis(
+ self,
+ ctx: &mut BlockContext,
+ block: &mut Block,
+ predecessors: &[(Word, Word)],
+ ) -> Word {
+ let merged_value = ctx.gen_id();
+ block
+ .body
+ .push(Instruction::phi(self, merged_value, predecessors));
+ merged_value
+ }
+}
+
+/// Selections that produce no merged values.
+///
+/// For example, `ImageStore` under `BoundsCheckPolicy::ReadZeroSkipWrite`
+/// either does the store or skips it, but in neither case does it produce a
+/// value.
+impl MergeTuple for () {
+ /// No phis need to be generated.
+ fn write_phis(self, _: &mut BlockContext, _: &mut Block, _: &[((), Word)]) {}
+}
diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs
new file mode 100644
index 0000000000..59fe739f2c
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/writer.rs
@@ -0,0 +1,1695 @@
+use super::{
+ helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
+ make_local, Block, BlockContext, CachedExpressions, EntryPointContext, Error, Function,
+ FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalType, LocalVariable,
+ LogicalLayout, LookupFunctionType, LookupType, LoopContext, Options, PhysicalLayout,
+ PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
+};
+use crate::{
+ arena::{Handle, UniqueArena},
+ back::spv::BindingInfo,
+ proc::{Alignment, TypeResolution},
+ valid::{FunctionInfo, ModuleInfo},
+};
+use spirv::Word;
+use std::collections::hash_map::Entry;
+
+struct FunctionInterface<'a> {
+ varying_ids: &'a mut Vec<Word>,
+ stage: crate::ShaderStage,
+}
+
+impl Function {
+ fn to_words(&self, sink: &mut impl Extend<Word>) {
+ self.signature.as_ref().unwrap().to_words(sink);
+ for argument in self.parameters.iter() {
+ argument.instruction.to_words(sink);
+ }
+ for (index, block) in self.blocks.iter().enumerate() {
+ Instruction::label(block.label_id).to_words(sink);
+ if index == 0 {
+ for local_var in self.variables.values() {
+ local_var.instruction.to_words(sink);
+ }
+ }
+ for instruction in block.body.iter() {
+ instruction.to_words(sink);
+ }
+ }
+ }
+}
+
+impl Writer {
+ pub fn new(options: &Options) -> Result<Self, Error> {
+ let (major, minor) = options.lang_version;
+ if major != 1 {
+ return Err(Error::UnsupportedVersion(major, minor));
+ }
+ let raw_version = ((major as u32) << 16) | ((minor as u32) << 8);
+
+ let mut capabilities_used = crate::FastHashSet::default();
+ capabilities_used.insert(spirv::Capability::Shader);
+
+ let mut id_gen = IdGenerator::default();
+ let gl450_ext_inst_id = id_gen.next();
+ let void_type = id_gen.next();
+
+ Ok(Writer {
+ physical_layout: PhysicalLayout::new(raw_version),
+ logical_layout: LogicalLayout::default(),
+ id_gen,
+ capabilities_available: options.capabilities.clone(),
+ capabilities_used,
+ extensions_used: crate::FastHashSet::default(),
+ debugs: vec![],
+ annotations: vec![],
+ flags: options.flags,
+ bounds_check_policies: options.bounds_check_policies,
+ void_type,
+ lookup_type: crate::FastHashMap::default(),
+ lookup_function: crate::FastHashMap::default(),
+ lookup_function_type: crate::FastHashMap::default(),
+ constant_ids: Vec::new(),
+ cached_constants: crate::FastHashMap::default(),
+ global_variables: Vec::new(),
+ binding_map: options.binding_map.clone(),
+ saved_cached: CachedExpressions::default(),
+ gl450_ext_inst_id,
+ temp_list: Vec::new(),
+ })
+ }
+
+ /// Reset `Writer` to its initial state, retaining any allocations.
+ ///
+ /// Why not just implement `Recyclable` for `Writer`? By design,
+ /// `Recyclable::recycle` requires ownership of the value, not just
+ /// `&mut`; see the trait documentation. But we need to use this method
+ /// from functions like `Writer::write`, which only have `&mut Writer`.
+ /// Workarounds include unsafe code (`std::ptr::read`, then `write`, ugh)
+ /// or something like a `Default` impl that returns an oddly-initialized
+ /// `Writer`, which is worse.
+ fn reset(&mut self) {
+ use super::recyclable::Recyclable;
+ use std::mem::take;
+
+ let mut id_gen = IdGenerator::default();
+ let gl450_ext_inst_id = id_gen.next();
+ let void_type = id_gen.next();
+
+ // Every field of the old writer that is not determined by the `Options`
+ // passed to `Writer::new` should be reset somehow.
+ let fresh = Writer {
+ // Copied from the old Writer:
+ flags: self.flags,
+ bounds_check_policies: self.bounds_check_policies,
+ capabilities_available: take(&mut self.capabilities_available),
+ binding_map: take(&mut self.binding_map),
+
+ // Initialized afresh:
+ id_gen,
+ void_type,
+ gl450_ext_inst_id,
+
+ // Recycled:
+ capabilities_used: take(&mut self.capabilities_used).recycle(),
+ extensions_used: take(&mut self.extensions_used).recycle(),
+ physical_layout: self.physical_layout.clone().recycle(),
+ logical_layout: take(&mut self.logical_layout).recycle(),
+ debugs: take(&mut self.debugs).recycle(),
+ annotations: take(&mut self.annotations).recycle(),
+ lookup_type: take(&mut self.lookup_type).recycle(),
+ lookup_function: take(&mut self.lookup_function).recycle(),
+ lookup_function_type: take(&mut self.lookup_function_type).recycle(),
+ constant_ids: take(&mut self.constant_ids).recycle(),
+ cached_constants: take(&mut self.cached_constants).recycle(),
+ global_variables: take(&mut self.global_variables).recycle(),
+ saved_cached: take(&mut self.saved_cached).recycle(),
+ temp_list: take(&mut self.temp_list).recycle(),
+ };
+
+ *self = fresh;
+
+ self.capabilities_used.insert(spirv::Capability::Shader);
+ }
+
+ /// Indicate that the code requires any one of the listed capabilities.
+ ///
+ /// If nothing in `capabilities` appears in the available capabilities
+ /// specified in the [`Options`] from which this `Writer` was created,
+ /// return an error. The `what` string is used in the error message to
+ /// explain what provoked the requirement. (If no available capabilities were
+ /// given, assume everything is available.)
+ ///
+ /// The first acceptable capability will be added to this `Writer`'s
+ /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
+ /// result. For this reason, more specific capabilities should be listed
+ /// before more general.
+ ///
+ /// [`capabilities_used`]: Writer::capabilities_used
+ pub(super) fn require_any(
+ &mut self,
+ what: &'static str,
+ capabilities: &[spirv::Capability],
+ ) -> Result<(), Error> {
+ match *capabilities {
+ [] => Ok(()),
+ [first, ..] => {
+ // Find the first acceptable capability, or return an error if
+ // there is none.
+ let selected = match self.capabilities_available {
+ None => first,
+ Some(ref available) => {
+ match capabilities.iter().find(|cap| available.contains(cap)) {
+ Some(&cap) => cap,
+ None => {
+ return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
+ }
+ }
+ }
+ };
+ self.capabilities_used.insert(selected);
+ Ok(())
+ }
+ }
+ }
+
+ /// Indicate that the code uses the given extension.
+ pub(super) fn use_extension(&mut self, extension: &'static str) {
+ self.extensions_used.insert(extension);
+ }
+
+ pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
+ match self.lookup_type.entry(lookup_ty) {
+ Entry::Occupied(e) => *e.get(),
+ Entry::Vacant(e) => {
+ let local = match lookup_ty {
+ LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
+ LookupType::Local(local) => local,
+ };
+
+ let id = self.id_gen.next();
+ e.insert(id);
+ self.write_type_declaration_local(id, local);
+ id
+ }
+ }
+ }
+
+ pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
+ let lookup_ty = match *tr {
+ TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
+ TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
+ };
+ self.get_type_id(lookup_ty)
+ }
+
+ pub(super) fn get_pointer_id(
+ &mut self,
+ arena: &UniqueArena<crate::Type>,
+ handle: Handle<crate::Type>,
+ class: spirv::StorageClass,
+ ) -> Result<Word, Error> {
+ let ty_id = self.get_type_id(LookupType::Handle(handle));
+ if let crate::TypeInner::Pointer { .. } = arena[handle].inner {
+ return Ok(ty_id);
+ }
+ let lookup_type = LookupType::Local(LocalType::Pointer {
+ base: handle,
+ class,
+ });
+ Ok(if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ })
+ }
+
+ pub(super) fn get_uint_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_float_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_float_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
+ let lookup_type = LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: Some(class),
+ });
+ if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let ty_id = self.get_float_type_id();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ }
+ }
+
+ pub(super) fn get_bool_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Bool,
+ width: 1,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
+ self.annotations
+ .push(Instruction::decorate(id, decoration, operands));
+ }
+
+ fn write_function(
+ &mut self,
+ ir_function: &crate::Function,
+ info: &FunctionInfo,
+ ir_module: &crate::Module,
+ mut interface: Option<FunctionInterface>,
+ ) -> Result<Word, Error> {
+ let mut function = Function::default();
+
+ for (handle, variable) in ir_function.local_variables.iter() {
+ let id = self.id_gen.next();
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = variable.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ let init_word = variable
+ .init
+ .map(|constant| self.constant_ids[constant.index()]);
+ let pointer_type_id =
+ self.get_pointer_id(&ir_module.types, variable.ty, spirv::StorageClass::Function)?;
+ let instruction = Instruction::variable(
+ pointer_type_id,
+ id,
+ spirv::StorageClass::Function,
+ init_word.or_else(|| {
+ let type_id = self.get_type_id(LookupType::Handle(variable.ty));
+ Some(self.write_constant_null(type_id))
+ }),
+ );
+ function
+ .variables
+ .insert(handle, LocalVariable { id, instruction });
+ }
+
+ let prelude_id = self.id_gen.next();
+ let mut prelude = Block::new(prelude_id);
+ let mut ep_context = EntryPointContext {
+ argument_ids: Vec::new(),
+ results: Vec::new(),
+ };
+
+ let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
+ for argument in ir_function.arguments.iter() {
+ let class = spirv::StorageClass::Input;
+ let handle_ty = ir_module.types[argument.ty].inner.is_handle();
+ let argument_type_id = match handle_ty {
+ true => self.get_pointer_id(
+ &ir_module.types,
+ argument.ty,
+ spirv::StorageClass::UniformConstant,
+ )?,
+ false => self.get_type_id(LookupType::Handle(argument.ty)),
+ };
+
+ if let Some(ref mut iface) = interface {
+ let id = if let Some(ref binding) = argument.binding {
+ let name = argument.name.as_deref();
+
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ argument.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(argument_type_id, id, varying_id, None));
+ id
+ } else if let crate::TypeInner::Struct { ref members, .. } =
+ ir_module.types[argument.ty].inner
+ {
+ let struct_id = self.id_gen.next();
+ let mut constituent_ids = Vec::with_capacity(members.len());
+ for member in members {
+ let type_id = self.get_type_id(LookupType::Handle(member.ty));
+ let name = member.name.as_deref();
+ let binding = member.binding.as_ref().unwrap();
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ member.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(type_id, id, varying_id, None));
+ constituent_ids.push(id);
+ }
+ prelude.body.push(Instruction::composite_construct(
+ argument_type_id,
+ struct_id,
+ &constituent_ids,
+ ));
+ struct_id
+ } else {
+ unreachable!("Missing argument binding on an entry point");
+ };
+ ep_context.argument_ids.push(id);
+ } else {
+ let argument_id = self.id_gen.next();
+ let instruction = Instruction::function_parameter(argument_type_id, argument_id);
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = argument.name {
+ self.debugs.push(Instruction::name(argument_id, name));
+ }
+ }
+ function.parameters.push(FunctionArgument {
+ instruction,
+ handle_id: if handle_ty {
+ let id = self.id_gen.next();
+ prelude.body.push(Instruction::load(
+ self.get_type_id(LookupType::Handle(argument.ty)),
+ id,
+ argument_id,
+ None,
+ ));
+ id
+ } else {
+ 0
+ },
+ });
+ parameter_type_ids.push(argument_type_id);
+ };
+ }
+
+ let return_type_id = match ir_function.result {
+ Some(ref result) => {
+ if let Some(ref mut iface) = interface {
+ let mut has_point_size = false;
+ let class = spirv::StorageClass::Output;
+ if let Some(ref binding) = result.binding {
+ has_point_size |=
+ *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
+ let type_id = self.get_type_id(LookupType::Handle(result.ty));
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ None,
+ result.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ ep_context.results.push(ResultMember {
+ id: varying_id,
+ type_id,
+ built_in: binding.to_built_in(),
+ });
+ } else if let crate::TypeInner::Struct { ref members, .. } =
+ ir_module.types[result.ty].inner
+ {
+ for member in members {
+ let type_id = self.get_type_id(LookupType::Handle(member.ty));
+ let name = member.name.as_deref();
+ let binding = member.binding.as_ref().unwrap();
+ has_point_size |=
+ *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ member.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ ep_context.results.push(ResultMember {
+ id: varying_id,
+ type_id,
+ built_in: binding.to_built_in(),
+ });
+ }
+ } else {
+ unreachable!("Missing result binding on an entry point");
+ }
+
+ if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
+ && iface.stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // add point size artificially
+ let varying_id = self.id_gen.next();
+ let pointer_type_id = self.get_float_pointer_type_id(class);
+ Instruction::variable(pointer_type_id, varying_id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+ self.decorate(
+ varying_id,
+ spirv::Decoration::BuiltIn,
+ &[spirv::BuiltIn::PointSize as u32],
+ );
+ iface.varying_ids.push(varying_id);
+
+ let default_value_id =
+ self.get_constant_scalar(crate::ScalarValue::Float(1.0), 4);
+ prelude
+ .body
+ .push(Instruction::store(varying_id, default_value_id, None));
+ }
+ self.void_type
+ } else {
+ self.get_type_id(LookupType::Handle(result.ty))
+ }
+ }
+ None => self.void_type,
+ };
+
+ let lookup_function_type = LookupFunctionType {
+ parameter_type_ids,
+ return_type_id,
+ };
+
+ let function_id = self.id_gen.next();
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = ir_function.name {
+ self.debugs.push(Instruction::name(function_id, name));
+ }
+ }
+
+ let function_type = self.get_function_type(lookup_function_type);
+ function.signature = Some(Instruction::function(
+ return_type_id,
+ function_id,
+ spirv::FunctionControl::empty(),
+ function_type,
+ ));
+
+ if interface.is_some() {
+ function.entry_point_context = Some(ep_context);
+ }
+
+ // fill up the `GlobalVariable::access_id`
+ for gv in self.global_variables.iter_mut() {
+ gv.reset_for_function();
+ }
+ for (handle, var) in ir_module.global_variables.iter() {
+ if info[handle].is_empty() {
+ continue;
+ }
+
+ let mut gv = self.global_variables[handle.index()].clone();
+
+ // Handle globals are pre-emitted and should be loaded automatically.
+ //
+ // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
+ let is_binding_array = match ir_module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { .. } => true,
+ _ => false,
+ };
+
+ if var.space == crate::AddressSpace::Handle && !is_binding_array {
+ let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(var_type_id, id, gv.var_id, None));
+ gv.access_id = gv.var_id;
+ gv.handle_id = id;
+ } else if global_needs_wrapper(ir_module, var) {
+ let class = map_storage_class(var.space);
+ let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?;
+ let index_id = self.get_index_constant(0);
+
+ let id = self.id_gen.next();
+ prelude.body.push(Instruction::access_chain(
+ pointer_type_id,
+ id,
+ gv.var_id,
+ &[index_id],
+ ));
+ gv.access_id = id;
+ } else {
+ // by default, the variable ID is accessed as is
+ gv.access_id = gv.var_id;
+ };
+
+ // work around borrow checking in the presence of `self.xxx()` calls
+ self.global_variables[handle.index()] = gv;
+ }
+
+ // Create a `BlockContext` for generating SPIR-V for the function's
+ // body.
+ let mut context = BlockContext {
+ ir_module,
+ ir_function,
+ fun_info: info,
+ function: &mut function,
+ // Re-use the cached expression table from prior functions.
+ cached: std::mem::take(&mut self.saved_cached),
+
+ // Steal the Writer's temp list for a bit.
+ temp_list: std::mem::take(&mut self.temp_list),
+ writer: self,
+ };
+
+ // fill up the pre-emitted expressions
+ context.cached.reset(ir_function.expressions.len());
+ for (handle, expr) in ir_function.expressions.iter() {
+ if expr.needs_pre_emit() {
+ context.cache_expression_value(handle, &mut prelude)?;
+ }
+ }
+
+ let main_id = context.gen_id();
+ context
+ .function
+ .consume(prelude, Instruction::branch(main_id));
+ context.write_block(
+ main_id,
+ &ir_function.body,
+ super::block::BlockExit::Return,
+ LoopContext::default(),
+ )?;
+
+ // Consume the `BlockContext`, ending its borrows and letting the
+ // `Writer` steal back its cached expression table and temp_list.
+ let BlockContext {
+ cached, temp_list, ..
+ } = context;
+ self.saved_cached = cached;
+ self.temp_list = temp_list;
+
+ function.to_words(&mut self.logical_layout.function_definitions);
+ Instruction::function_end().to_words(&mut self.logical_layout.function_definitions);
+
+ Ok(function_id)
+ }
+
+ fn write_execution_mode(
+ &mut self,
+ function_id: Word,
+ mode: spirv::ExecutionMode,
+ ) -> Result<(), Error> {
+ //self.check(mode.required_capabilities())?;
+ Instruction::execution_mode(function_id, mode, &[])
+ .to_words(&mut self.logical_layout.execution_modes);
+ Ok(())
+ }
+
+ // TODO Move to instructions module
+ fn write_entry_point(
+ &mut self,
+ entry_point: &crate::EntryPoint,
+ info: &FunctionInfo,
+ ir_module: &crate::Module,
+ ) -> Result<Instruction, Error> {
+ let mut interface_ids = Vec::new();
+ let function_id = self.write_function(
+ &entry_point.function,
+ info,
+ ir_module,
+ Some(FunctionInterface {
+ varying_ids: &mut interface_ids,
+ stage: entry_point.stage,
+ }),
+ )?;
+
+ let exec_model = match entry_point.stage {
+ crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
+ crate::ShaderStage::Fragment => {
+ self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
+ if let Some(ref result) = entry_point.function.result {
+ if contains_builtin(
+ result.binding.as_ref(),
+ result.ty,
+ &ir_module.types,
+ crate::BuiltIn::FragDepth,
+ ) {
+ self.write_execution_mode(
+ function_id,
+ spirv::ExecutionMode::DepthReplacing,
+ )?;
+ }
+ }
+ spirv::ExecutionModel::Fragment
+ }
+ crate::ShaderStage::Compute => {
+ let execution_mode = spirv::ExecutionMode::LocalSize;
+ //self.check(execution_mode.required_capabilities())?;
+ Instruction::execution_mode(
+ function_id,
+ execution_mode,
+ &entry_point.workgroup_size,
+ )
+ .to_words(&mut self.logical_layout.execution_modes);
+ spirv::ExecutionModel::GLCompute
+ }
+ };
+ //self.check(exec_model.required_capabilities())?;
+
+ Ok(Instruction::entry_point(
+ exec_model,
+ function_id,
+ &entry_point.name,
+ interface_ids.as_slice(),
+ ))
+ }
+
+ fn make_scalar(
+ &mut self,
+ id: Word,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ ) -> Instruction {
+ use crate::ScalarKind as Sk;
+
+ let bits = (width * BITS_PER_BYTE) as u32;
+ match kind {
+ Sk::Sint | Sk::Uint => {
+ let signedness = if kind == Sk::Sint {
+ super::instructions::Signedness::Signed
+ } else {
+ super::instructions::Signedness::Unsigned
+ };
+ let cap = match bits {
+ 8 => Some(spirv::Capability::Int8),
+ 16 => Some(spirv::Capability::Int16),
+ 64 => Some(spirv::Capability::Int64),
+ _ => None,
+ };
+ if let Some(cap) = cap {
+ self.capabilities_used.insert(cap);
+ }
+ Instruction::type_int(id, bits, signedness)
+ }
+ Sk::Float => {
+ if bits == 64 {
+ self.capabilities_used.insert(spirv::Capability::Float64);
+ }
+ Instruction::type_float(id, bits)
+ }
+ Sk::Bool => Instruction::type_bool(id),
+ }
+ }
+
+ fn request_image_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
+ if let crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } = *inner
+ {
+ let sampled = match class {
+ crate::ImageClass::Sampled { .. } => true,
+ crate::ImageClass::Depth { .. } => true,
+ crate::ImageClass::Storage { format, .. } => {
+ self.request_image_format_capabilities(format.into())?;
+ false
+ }
+ };
+
+ match dim {
+ crate::ImageDimension::D1 => {
+ if sampled {
+ self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
+ } else {
+ self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
+ }
+ }
+ crate::ImageDimension::Cube if arrayed => {
+ if sampled {
+ self.require_any(
+ "sampled cube array images",
+ &[spirv::Capability::SampledCubeArray],
+ )?;
+ } else {
+ self.require_any(
+ "cube array storage images",
+ &[spirv::Capability::ImageCubeArray],
+ )?;
+ }
+ }
+ _ => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
+ let instruction = match local_ty {
+ LocalType::Value {
+ vector_size: None,
+ kind,
+ width,
+ pointer_space: None,
+ } => self.make_scalar(id, kind, width),
+ LocalType::Value {
+ vector_size: Some(size),
+ kind,
+ width,
+ pointer_space: None,
+ } => {
+ let scalar_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind,
+ width,
+ pointer_space: None,
+ }));
+ Instruction::type_vector(id, scalar_id, size)
+ }
+ LocalType::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let vector_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(rows),
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ }));
+ Instruction::type_matrix(id, vector_id, columns)
+ }
+ LocalType::Pointer { base, class } => {
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ Instruction::type_pointer(id, class, type_id)
+ }
+ LocalType::Value {
+ vector_size,
+ kind,
+ width,
+ pointer_space: Some(class),
+ } => {
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size,
+ kind,
+ width,
+ pointer_space: None,
+ }));
+ Instruction::type_pointer(id, class, type_id)
+ }
+ LocalType::Image(image) => {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ kind: image.sampled_type,
+ width: 4,
+ pointer_space: None,
+ };
+ let type_id = self.get_type_id(LookupType::Local(local_type));
+ Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
+ }
+ LocalType::Sampler => Instruction::type_sampler(id),
+ LocalType::SampledImage { image_type_id } => {
+ Instruction::type_sampled_image(id, image_type_id)
+ }
+ LocalType::BindingArray { base, size } => {
+ let inner_ty = self.get_type_id(LookupType::Handle(base));
+ let scalar_id = self.get_constant_scalar(crate::ScalarValue::Uint(size), 4);
+ Instruction::type_array(id, inner_ty, scalar_id)
+ }
+ LocalType::PointerToBindingArray { base, size } => {
+ let inner_ty =
+ self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size }));
+ Instruction::type_pointer(id, spirv::StorageClass::UniformConstant, inner_ty)
+ }
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ }
+
+ fn write_type_declaration_arena(
+ &mut self,
+ arena: &UniqueArena<crate::Type>,
+ handle: Handle<crate::Type>,
+ ) -> Result<Word, Error> {
+ let ty = &arena[handle];
+ let id = if let Some(local) = make_local(&ty.inner) {
+ // This type can be represented as a `LocalType`, so check if we've
+ // already written an instruction for it. If not, do so now, with
+ // `write_type_declaration_local`.
+ match self.lookup_type.entry(LookupType::Local(local)) {
+ // We already have an id for this `LocalType`.
+ Entry::Occupied(e) => *e.get(),
+
+ // It's a type we haven't seen before.
+ Entry::Vacant(e) => {
+ let id = self.id_gen.next();
+ e.insert(id);
+
+ self.write_type_declaration_local(id, local);
+
+ // If it's an image type, request SPIR-V capabilities here, so
+ // write_type_declaration_local can stay infallible.
+ self.request_image_capabilities(&ty.inner)?;
+
+ id
+ }
+ }
+ } else {
+ use spirv::Decoration;
+
+ let id = self.id_gen.next();
+ let instruction = match ty.inner {
+ crate::TypeInner::Array { base, size, stride } => {
+ self.decorate(id, Decoration::ArrayStride, &[stride]);
+
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let length_id = self.constant_ids[const_handle.index()];
+ Instruction::type_array(id, type_id, length_id)
+ }
+ crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
+ }
+ }
+ crate::TypeInner::BindingArray { base, size } => {
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let length_id = self.constant_ids[const_handle.index()];
+ Instruction::type_array(id, type_id, length_id)
+ }
+ crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
+ }
+ }
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => {
+ let mut member_ids = Vec::with_capacity(members.len());
+ for (index, member) in members.iter().enumerate() {
+ self.decorate_struct_member(id, index, member, arena)?;
+ let member_id = self.get_type_id(LookupType::Handle(member.ty));
+ member_ids.push(member_id);
+ }
+ Instruction::type_struct(id, member_ids.as_slice())
+ }
+
+ // These all have TypeLocal representations, so they should have been
+ // handled by `write_type_declaration_local` above.
+ crate::TypeInner::Scalar { .. }
+ | crate::TypeInner::Atomic { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::Pointer { .. }
+ | crate::TypeInner::ValuePointer { .. }
+ | crate::TypeInner::Image { .. }
+ | crate::TypeInner::Sampler { .. } => unreachable!(),
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ };
+
+ // Add this handle as a new alias for that type.
+ self.lookup_type.insert(LookupType::Handle(handle), id);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = ty.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ Ok(id)
+ }
+
+ fn request_image_format_capabilities(
+ &mut self,
+ format: spirv::ImageFormat,
+ ) -> Result<(), Error> {
+ use spirv::ImageFormat as If;
+ match format {
+ If::Rg32f
+ | If::Rg16f
+ | If::R11fG11fB10f
+ | If::R16f
+ | If::Rgba16
+ | If::Rgb10A2
+ | If::Rg16
+ | If::Rg8
+ | If::R16
+ | If::R8
+ | If::Rgba16Snorm
+ | If::Rg16Snorm
+ | If::Rg8Snorm
+ | If::R16Snorm
+ | If::R8Snorm
+ | If::Rg32i
+ | If::Rg16i
+ | If::Rg8i
+ | If::R16i
+ | If::R8i
+ | If::Rgb10a2ui
+ | If::Rg32ui
+ | If::Rg16ui
+ | If::Rg8ui
+ | If::R16ui
+ | If::R8ui => self.require_any(
+ "storage image format",
+ &[spirv::Capability::StorageImageExtendedFormats],
+ ),
+ If::R64ui | If::R64i => self.require_any(
+ "64-bit integer storage image format",
+ &[spirv::Capability::Int64ImageEXT],
+ ),
+ If::Unknown
+ | If::Rgba32f
+ | If::Rgba16f
+ | If::R32f
+ | If::Rgba8
+ | If::Rgba8Snorm
+ | If::Rgba32i
+ | If::Rgba16i
+ | If::Rgba8i
+ | If::R32i
+ | If::Rgba32ui
+ | If::Rgba16ui
+ | If::Rgba8ui
+ | If::R32ui => Ok(()),
+ }
+ }
+
+ pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
+ self.get_constant_scalar(crate::ScalarValue::Uint(index as _), 4)
+ }
+
+ pub(super) fn get_constant_scalar(
+ &mut self,
+ value: crate::ScalarValue,
+ width: crate::Bytes,
+ ) -> Word {
+ if let Some(&id) = self.cached_constants.get(&(value, width)) {
+ return id;
+ }
+ let id = self.id_gen.next();
+ self.write_constant_scalar(id, &value, width, None);
+ self.cached_constants.insert((value, width), id);
+ id
+ }
+
+ fn write_constant_scalar(
+ &mut self,
+ id: Word,
+ value: &crate::ScalarValue,
+ width: crate::Bytes,
+ debug_name: Option<&String>,
+ ) {
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: value.scalar_kind(),
+ width,
+ pointer_space: None,
+ }));
+ let (solo, pair);
+ let instruction = match *value {
+ crate::ScalarValue::Sint(val) => {
+ let words = match width {
+ 4 => {
+ solo = [val as u32];
+ &solo[..]
+ }
+ 8 => {
+ pair = [val as u32, (val >> 32) as u32];
+ &pair
+ }
+ _ => unreachable!(),
+ };
+ Instruction::constant(type_id, id, words)
+ }
+ crate::ScalarValue::Uint(val) => {
+ let words = match width {
+ 4 => {
+ solo = [val as u32];
+ &solo[..]
+ }
+ 8 => {
+ pair = [val as u32, (val >> 32) as u32];
+ &pair
+ }
+ _ => unreachable!(),
+ };
+ Instruction::constant(type_id, id, words)
+ }
+ crate::ScalarValue::Float(val) => {
+ let words = match width {
+ 4 => {
+ solo = [(val as f32).to_bits()];
+ &solo[..]
+ }
+ 8 => {
+ let bits = f64::to_bits(val);
+ pair = [bits as u32, (bits >> 32) as u32];
+ &pair
+ }
+ _ => unreachable!(),
+ };
+ Instruction::constant(type_id, id, words)
+ }
+ crate::ScalarValue::Bool(true) => Instruction::constant_true(type_id, id),
+ crate::ScalarValue::Bool(false) => Instruction::constant_false(type_id, id),
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ }
+
+ fn write_constant_composite(
+ &mut self,
+ id: Word,
+ ty: Handle<crate::Type>,
+ components: &[Handle<crate::Constant>],
+ ) -> Result<(), Error> {
+ let mut constituent_ids = Vec::with_capacity(components.len());
+ for constituent in components.iter() {
+ let constituent_id = self.constant_ids[constituent.index()];
+ constituent_ids.push(constituent_id);
+ }
+
+ let type_id = self.get_type_id(LookupType::Handle(ty));
+ Instruction::constant_composite(type_id, id, constituent_ids.as_slice())
+ .to_words(&mut self.logical_layout.declarations);
+ Ok(())
+ }
+
+ pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
+ let null_id = self.id_gen.next();
+ Instruction::constant_null(type_id, null_id)
+ .to_words(&mut self.logical_layout.declarations);
+ null_id
+ }
+
+ /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
+ ///
+ /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
+ /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
+ /// interface is represented by global variables in the `Input` and `Output`
+ /// storage classes, with decorations indicating which builtin or location
+ /// each variable corresponds to.
+ ///
+ /// This function emits a single global `OpVariable` for a single value from
+ /// the interface, and adds appropriate decorations to indicate which
+ /// builtin or location it represents, how it should be interpolated, and so
+ /// on. The `class` argument gives the variable's SPIR-V storage class,
+ /// which should be either [`Input`] or [`Output`].
+ ///
+ /// [`Binding`]: crate::Binding
+ /// [`Function`]: crate::Function
+ /// [`EntryPoint`]: crate::EntryPoint
+ /// [`Input`]: spirv::StorageClass::Input
+ /// [`Output`]: spirv::StorageClass::Output
+ fn write_varying(
+ &mut self,
+ ir_module: &crate::Module,
+ stage: crate::ShaderStage,
+ class: spirv::StorageClass,
+ debug_name: Option<&str>,
+ ty: Handle<crate::Type>,
+ binding: &crate::Binding,
+ ) -> Result<Word, Error> {
+ let id = self.id_gen.next();
+ let pointer_type_id = self.get_pointer_id(&ir_module.types, ty, class)?;
+ Instruction::variable(pointer_type_id, id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+
+ if self
+ .flags
+ .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
+ {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ use spirv::{BuiltIn, Decoration};
+
+ match *binding {
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => {
+ self.decorate(id, Decoration::Location, &[location]);
+
+ // The Vulkan spec says: VUID-StandaloneSpirv-Flat-06202
+ //
+ // > The Flat, NoPerspective, Sample, and Centroid decorations
+ // > must not be used on variables with the Input storage class in
+ // > a vertex shader
+ if class != spirv::StorageClass::Input || stage != crate::ShaderStage::Vertex {
+ match interpolation {
+ // Perspective-correct interpolation is the default in SPIR-V.
+ None | Some(crate::Interpolation::Perspective) => (),
+ Some(crate::Interpolation::Flat) => {
+ self.decorate(id, Decoration::Flat, &[]);
+ }
+ Some(crate::Interpolation::Linear) => {
+ self.decorate(id, Decoration::NoPerspective, &[]);
+ }
+ }
+ }
+
+ match sampling {
+ // Center sampling is the default in SPIR-V.
+ None | Some(crate::Sampling::Center) => (),
+ Some(crate::Sampling::Centroid) => {
+ self.decorate(id, Decoration::Centroid, &[]);
+ }
+ Some(crate::Sampling::Sample) => {
+ self.require_any(
+ "per-sample interpolation",
+ &[spirv::Capability::SampleRateShading],
+ )?;
+ self.decorate(id, Decoration::Sample, &[]);
+ }
+ }
+ }
+ crate::Binding::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let built_in = match built_in {
+ Bi::Position { invariant } => {
+ if invariant {
+ self.decorate(id, Decoration::Invariant, &[]);
+ }
+
+ if class == spirv::StorageClass::Output {
+ BuiltIn::Position
+ } else {
+ BuiltIn::FragCoord
+ }
+ }
+ Bi::ViewIndex => {
+ self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
+ BuiltIn::ViewIndex
+ }
+ // vertex
+ Bi::BaseInstance => BuiltIn::BaseInstance,
+ Bi::BaseVertex => BuiltIn::BaseVertex,
+ Bi::ClipDistance => BuiltIn::ClipDistance,
+ Bi::CullDistance => BuiltIn::CullDistance,
+ Bi::InstanceIndex => BuiltIn::InstanceIndex,
+ Bi::PointSize => BuiltIn::PointSize,
+ Bi::VertexIndex => BuiltIn::VertexIndex,
+ // fragment
+ Bi::FragDepth => BuiltIn::FragDepth,
+ Bi::FrontFacing => BuiltIn::FrontFacing,
+ Bi::PrimitiveIndex => {
+ self.require_any(
+ "`primitive_index` built-in",
+ &[spirv::Capability::Geometry],
+ )?;
+ BuiltIn::PrimitiveId
+ }
+ Bi::SampleIndex => {
+ self.require_any(
+ "`sample_index` built-in",
+ &[spirv::Capability::SampleRateShading],
+ )?;
+
+ BuiltIn::SampleId
+ }
+ Bi::SampleMask => BuiltIn::SampleMask,
+ // compute
+ Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
+ Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
+ Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
+ Bi::WorkGroupId => BuiltIn::WorkgroupId,
+ Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
+ Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
+ };
+
+ self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
+
+ use crate::ScalarKind as Sk;
+
+ // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
+ //
+ // > Any variable with integer or double-precision floating-
+ // > point type and with Input storage class in a fragment
+ // > shader, must be decorated Flat
+ if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
+ let is_flat = match ir_module.types[ty].inner {
+ crate::TypeInner::Scalar { kind, .. }
+ | crate::TypeInner::Vector { kind, .. } => match kind {
+ Sk::Uint | Sk::Sint | Sk::Bool => true,
+ Sk::Float => false,
+ },
+ _ => false,
+ };
+
+ if is_flat {
+ self.decorate(id, Decoration::Flat, &[]);
+ }
+ }
+ }
+ }
+
+ Ok(id)
+ }
+
+ fn write_global_variable(
+ &mut self,
+ ir_module: &crate::Module,
+ global_variable: &crate::GlobalVariable,
+ ) -> Result<Word, Error> {
+ use spirv::Decoration;
+
+ let id = self.id_gen.next();
+ let class = map_storage_class(global_variable.space);
+
+ //self.check(class.required_capabilities())?;
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = global_variable.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ let storage_access = match global_variable.space {
+ crate::AddressSpace::Storage { access } => Some(access),
+ _ => match ir_module.types[global_variable.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => Some(access),
+ _ => None,
+ },
+ };
+ if let Some(storage_access) = storage_access {
+ if !storage_access.contains(crate::StorageAccess::LOAD) {
+ self.decorate(id, Decoration::NonReadable, &[]);
+ }
+ if !storage_access.contains(crate::StorageAccess::STORE) {
+ self.decorate(id, Decoration::NonWritable, &[]);
+ }
+ }
+
+ let mut substitute_inner_type_lookup = None;
+ if let Some(ref res_binding) = global_variable.binding {
+ self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]);
+ self.decorate(id, Decoration::Binding, &[res_binding.binding]);
+
+ if let Some(&BindingInfo {
+ binding_array_size: Some(remapped_binding_array_size),
+ }) = self.binding_map.get(res_binding)
+ {
+ if let crate::TypeInner::BindingArray { base, .. } =
+ ir_module.types[global_variable.ty].inner
+ {
+ substitute_inner_type_lookup =
+ Some(LookupType::Local(LocalType::PointerToBindingArray {
+ base,
+ size: remapped_binding_array_size as u64,
+ }))
+ }
+ } else {
+ }
+ };
+
+ let init_word = global_variable
+ .init
+ .map(|constant| self.constant_ids[constant.index()]);
+ let inner_type_id = self.get_type_id(
+ substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
+ );
+
+ // generate the wrapping structure if needed
+ let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
+ let wrapper_type_id = self.id_gen.next();
+
+ self.decorate(wrapper_type_id, Decoration::Block, &[]);
+ let member = crate::StructMember {
+ name: None,
+ ty: global_variable.ty,
+ binding: None,
+ offset: 0,
+ };
+ self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;
+
+ Instruction::type_struct(wrapper_type_id, &[inner_type_id])
+ .to_words(&mut self.logical_layout.declarations);
+
+ let pointer_type_id = self.id_gen.next();
+ Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
+ .to_words(&mut self.logical_layout.declarations);
+
+ pointer_type_id
+ } else {
+ // This is a global variable in the Storage address space. The only
+ // way it could have `global_needs_wrapper() == false` is if it has
+ // a runtime-sized array. In this case, we need to decorate it with
+ // Block.
+ if let crate::AddressSpace::Storage { .. } = global_variable.space {
+ self.decorate(inner_type_id, Decoration::Block, &[]);
+ }
+ if substitute_inner_type_lookup.is_some() {
+ inner_type_id
+ } else {
+ self.get_pointer_id(&ir_module.types, global_variable.ty, class)?
+ }
+ };
+
+ let init_word = match global_variable.space {
+ crate::AddressSpace::Private => {
+ init_word.or_else(|| Some(self.write_constant_null(inner_type_id)))
+ }
+ _ => init_word,
+ };
+
+ Instruction::variable(pointer_type_id, id, class, init_word)
+ .to_words(&mut self.logical_layout.declarations);
+ Ok(id)
+ }
+
+ /// Write the necessary decorations for a struct member.
+ ///
+ /// Emit decorations for the `index`'th member of the struct type
+ /// designated by `struct_id`, described by `member`.
+ fn decorate_struct_member(
+ &mut self,
+ struct_id: Word,
+ index: usize,
+ member: &crate::StructMember,
+ arena: &UniqueArena<crate::Type>,
+ ) -> Result<(), Error> {
+ use spirv::Decoration;
+
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::Offset,
+ &[member.offset],
+ ));
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = member.name {
+ self.debugs
+ .push(Instruction::member_name(struct_id, index as u32, name));
+ }
+ }
+
+ // Matrices and arrays of matrices both require decorations,
+ // so "see through" an array to determine if they're needed.
+ let member_array_subty_inner = match arena[member.ty].inner {
+ crate::TypeInner::Array { base, .. } => &arena[base].inner,
+ ref other => other,
+ };
+ if let crate::TypeInner::Matrix {
+ columns: _,
+ rows,
+ width,
+ } = *member_array_subty_inner
+ {
+ let byte_stride = Alignment::from(rows) * width as u32;
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::ColMajor,
+ &[],
+ ));
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::MatrixStride,
+ &[byte_stride],
+ ));
+ }
+
+ Ok(())
+ }
+
+ fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
+ match self
+ .lookup_function_type
+ .entry(lookup_function_type.clone())
+ {
+ Entry::Occupied(e) => *e.get(),
+ _ => {
+ let id = self.id_gen.next();
+ let instruction = Instruction::type_function(
+ id,
+ lookup_function_type.return_type_id,
+ &lookup_function_type.parameter_type_ids,
+ );
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_function_type.insert(lookup_function_type, id);
+ id
+ }
+ }
+ }
+
+ fn write_physical_layout(&mut self) {
+ self.physical_layout.bound = self.id_gen.0 + 1;
+ }
+
+ fn write_logical_layout(
+ &mut self,
+ ir_module: &crate::Module,
+ mod_info: &ModuleInfo,
+ ep_index: Option<usize>,
+ ) -> Result<(), Error> {
+ fn has_view_index_check(
+ ir_module: &crate::Module,
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ ) -> bool {
+ match ir_module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
+ has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
+ }),
+ _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
+ }
+ }
+
+ let has_storage_buffers =
+ ir_module
+ .global_variables
+ .iter()
+ .any(|(_, var)| match var.space {
+ crate::AddressSpace::Storage { .. } => true,
+ _ => false,
+ });
+ let has_view_index = ir_module
+ .entry_points
+ .iter()
+ .flat_map(|entry| entry.function.arguments.iter())
+ .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
+
+ if self.physical_layout.version < 0x10300 && has_storage_buffers {
+ // enable the storage buffer class on < SPV-1.3
+ Instruction::extension("SPV_KHR_storage_buffer_storage_class")
+ .to_words(&mut self.logical_layout.extensions);
+ }
+ if has_view_index {
+ Instruction::extension("SPV_KHR_multiview")
+ .to_words(&mut self.logical_layout.extensions)
+ }
+ Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
+ Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
+ .to_words(&mut self.logical_layout.ext_inst_imports);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ self.debugs
+ .push(Instruction::source(spirv::SourceLanguage::GLSL, 450));
+ }
+
+ self.constant_ids.resize(ir_module.constants.len(), 0);
+ // first, output all the scalar constants
+ for (handle, constant) in ir_module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Composite { .. } => continue,
+ crate::ConstantInner::Scalar { width, ref value } => {
+ self.constant_ids[handle.index()] = match constant.name {
+ Some(ref name) => {
+ let id = self.id_gen.next();
+ self.write_constant_scalar(id, value, width, Some(name));
+ id
+ }
+ None => self.get_constant_scalar(*value, width),
+ };
+ }
+ }
+ }
+
+ // then all types, some of them may rely on constants and struct type set
+ for (handle, _) in ir_module.types.iter() {
+ self.write_type_declaration_arena(&ir_module.types, handle)?;
+ }
+
+ // the all the composite constants, they rely on types
+ for (handle, constant) in ir_module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Scalar { .. } => continue,
+ crate::ConstantInner::Composite { ty, ref components } => {
+ let id = self.id_gen.next();
+ self.constant_ids[handle.index()] = id;
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = constant.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ self.write_constant_composite(id, ty, components)?;
+ }
+ }
+ }
+ debug_assert_eq!(self.constant_ids.iter().position(|&id| id == 0), None);
+
+ // now write all globals
+ for (handle, var) in ir_module.global_variables.iter() {
+ // If a single entry point was specified, only write `OpVariable` instructions
+ // for the globals it actually uses. Emit dummies for the others,
+ // to preserve the indices in `global_variables`.
+ let gvar = match ep_index {
+ Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
+ GlobalVariable::dummy()
+ }
+ _ => {
+ let id = self.write_global_variable(ir_module, var)?;
+ GlobalVariable::new(id)
+ }
+ };
+ self.global_variables.push(gvar);
+ }
+
+ // all functions
+ for (handle, ir_function) in ir_module.functions.iter() {
+ let info = &mod_info[handle];
+ if let Some(index) = ep_index {
+ let ep_info = mod_info.get_entry_point(index);
+ // If this function uses globals that we omitted from the SPIR-V
+ // because the entry point and its callees didn't use them,
+ // then we must skip it.
+ if !ep_info.dominates_global_use(info) {
+ log::info!("Skip function {:?}", ir_function.name);
+ continue;
+ }
+ }
+ let id = self.write_function(ir_function, info, ir_module, None)?;
+ self.lookup_function.insert(handle, id);
+ }
+
+ // and entry points
+ for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
+ if ep_index.is_some() && ep_index != Some(index) {
+ continue;
+ }
+ let info = mod_info.get_entry_point(index);
+ let ep_instruction = self.write_entry_point(ir_ep, info, ir_module)?;
+ ep_instruction.to_words(&mut self.logical_layout.entry_points);
+ }
+
+ for capability in self.capabilities_used.iter() {
+ Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
+ }
+ for extension in self.extensions_used.iter() {
+ Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
+ }
+ if ir_module.entry_points.is_empty() {
+ // SPIR-V doesn't like modules without entry points
+ Instruction::capability(spirv::Capability::Linkage)
+ .to_words(&mut self.logical_layout.capabilities);
+ }
+
+ let addressing_model = spirv::AddressingModel::Logical;
+ let memory_model = spirv::MemoryModel::GLSL450;
+ //self.check(addressing_model.required_capabilities())?;
+ //self.check(memory_model.required_capabilities())?;
+
+ Instruction::memory_model(addressing_model, memory_model)
+ .to_words(&mut self.logical_layout.memory_model);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ for debug in self.debugs.iter() {
+ debug.to_words(&mut self.logical_layout.debugs);
+ }
+ }
+
+ for annotation in self.annotations.iter() {
+ annotation.to_words(&mut self.logical_layout.annotations);
+ }
+
+ Ok(())
+ }
+
+ pub fn write(
+ &mut self,
+ ir_module: &crate::Module,
+ info: &ModuleInfo,
+ pipeline_options: Option<&PipelineOptions>,
+ words: &mut Vec<Word>,
+ ) -> Result<(), Error> {
+ self.reset();
+
+ // Try to find the entry point and corresponding index
+ let ep_index = match pipeline_options {
+ Some(po) => {
+ let index = ir_module
+ .entry_points
+ .iter()
+ .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
+ .ok_or(Error::EntryPointNotFound)?;
+ Some(index)
+ }
+ None => None,
+ };
+
+ self.write_logical_layout(ir_module, info, ep_index)?;
+ self.write_physical_layout();
+
+ self.physical_layout.in_words(words);
+ self.logical_layout.in_words(words);
+ Ok(())
+ }
+
+ /// Return the set of capabilities the last module written used.
+ pub const fn get_capabilities_used(&self) -> &crate::FastHashSet<spirv::Capability> {
+ &self.capabilities_used
+ }
+}
+
+#[test]
+fn test_write_physical_layout() {
+ let mut writer = Writer::new(&Options::default()).unwrap();
+ assert_eq!(writer.physical_layout.bound, 0);
+ writer.write_physical_layout();
+ assert_eq!(writer.physical_layout.bound, 3);
+}
diff --git a/third_party/rust/naga/src/back/wgsl/mod.rs b/third_party/rust/naga/src/back/wgsl/mod.rs
new file mode 100644
index 0000000000..d731b1ca0c
--- /dev/null
+++ b/third_party/rust/naga/src/back/wgsl/mod.rs
@@ -0,0 +1,52 @@
+/*!
+Backend for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+mod writer;
+
+use thiserror::Error;
+
+pub use writer::{Writer, WriterFlags};
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error(transparent)]
+ FmtError(#[from] std::fmt::Error),
+ #[error("{0}")]
+ Custom(String),
+ #[error("{0}")]
+ Unimplemented(String), // TODO: Error used only during development
+ #[error("Unsupported math function: {0:?}")]
+ UnsupportedMathFunction(crate::MathFunction),
+ #[error("Unsupported relational function: {0:?}")]
+ UnsupportedRelationalFunction(crate::RelationalFunction),
+}
+
+pub fn write_string(
+ module: &crate::Module,
+ info: &crate::valid::ModuleInfo,
+ flags: WriterFlags,
+) -> Result<String, Error> {
+ let mut w = Writer::new(String::new(), flags);
+ w.write(module, info)?;
+ let output = w.finish();
+ Ok(output)
+}
+
+impl crate::AtomicFunction {
+ const fn to_wgsl(self) -> &'static str {
+ match self {
+ Self::Add => "Add",
+ Self::Subtract => "Sub",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { .. } => "CompareExchangeWeak",
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs
new file mode 100644
index 0000000000..817fa78b0a
--- /dev/null
+++ b/third_party/rust/naga/src/back/wgsl/writer.rs
@@ -0,0 +1,2061 @@
+use super::Error;
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, Module, ShaderStage, TypeInner,
+};
+use std::fmt::Write;
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+/// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes)
+enum Attribute {
+ Binding(u32),
+ BuiltIn(crate::BuiltIn),
+ Group(u32),
+ Invariant,
+ Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
+ Location(u32),
+ Stage(ShaderStage),
+ WorkGroupSize([u32; 3]),
+}
+
+/// The WGSL form that `write_expr_with_indirection` should use to render a Naga
+/// expression.
+///
+/// Sometimes a Naga `Expression` alone doesn't provide enough information to
+/// choose the right rendering for it in WGSL. For example, one natural WGSL
+/// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since
+/// `LocalVariable` produces a pointer to the local variable's storage. But when
+/// rendering a `Store` statement, the `pointer` operand must be the left hand
+/// side of a WGSL assignment, so the proper rendering is `x`.
+///
+/// The caller of `write_expr_with_indirection` must provide an `Expected` value
+/// to indicate how ambiguous expressions should be rendered.
+#[derive(Clone, Copy, Debug)]
+enum Indirection {
+ /// Render pointer-construction expressions as WGSL `ptr`-typed expressions.
+ ///
+ /// This is the right choice for most cases. Whenever a Naga pointer
+ /// expression is not the `pointer` operand of a `Load` or `Store`, it
+ /// must be a WGSL pointer expression.
+ Ordinary,
+
+ /// Render pointer-construction expressions as WGSL reference-typed
+ /// expressions.
+ ///
+ /// For example, this is the right choice for the `pointer` operand when
+ /// rendering a `Store` statement as a WGSL assignment.
+ Reference,
+}
+
+bitflags::bitflags! {
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct WriterFlags: u32 {
+ /// Always annotate the type information instead of inferring.
+ const EXPLICIT_TYPES = 0x1;
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ flags: WriterFlags,
+ names: crate::FastHashMap<NameKey, String>,
+ namer: proc::Namer,
+ named_expressions: crate::NamedExpressions,
+ ep_results: Vec<(ShaderStage, Handle<crate::Type>)>,
+}
+
+impl<W: Write> Writer<W> {
+ pub fn new(out: W, flags: WriterFlags) -> Self {
+ Writer {
+ out,
+ flags,
+ names: crate::FastHashMap::default(),
+ namer: proc::Namer::default(),
+ named_expressions: crate::NamedExpressions::default(),
+ ep_results: vec![],
+ }
+ }
+
+ fn reset(&mut self, module: &Module) {
+ self.names.clear();
+ self.namer.reset(
+ module,
+ crate::keywords::wgsl::RESERVED,
+ // an identifier must not start with two underscore
+ &["__"],
+ &mut self.names,
+ );
+ self.named_expressions.clear();
+ self.ep_results.clear();
+ }
+
+ pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
+ self.reset(module);
+
+ // Save all ep result types
+ for (_, ep) in module.entry_points.iter().enumerate() {
+ if let Some(ref result) = ep.function.result {
+ self.ep_results.push((ep.stage, result.ty));
+ }
+ }
+
+ // Write all structs
+ for (handle, ty) in module.types.iter() {
+ if let TypeInner::Struct {
+ ref members,
+ span: _,
+ } = ty.inner
+ {
+ self.write_struct(module, handle, members)?;
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write all constants
+ for (handle, constant) in module.constants.iter() {
+ if constant.name.is_some() {
+ self.write_global_constant(module, &constant.inner, handle)?;
+ }
+ }
+
+ // Write all globals
+ for (ty, global) in module.global_variables.iter() {
+ self.write_global(module, global, ty)?;
+ }
+
+ if !module.global_variables.is_empty() {
+ // Add extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Write all regular functions
+ for (handle, function) in module.functions.iter() {
+ let fun_info = &info[handle];
+
+ let func_ctx = back::FunctionCtx {
+ ty: back::FunctionType::Function(handle),
+ info: fun_info,
+ expressions: &function.expressions,
+ named_expressions: &function.named_expressions,
+ };
+
+ // Write the function
+ self.write_function(module, function, &func_ctx)?;
+
+ writeln!(self.out)?;
+ }
+
+ // Write all entry points
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let attributes = match ep.stage {
+ ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
+ ShaderStage::Compute => vec![
+ Attribute::Stage(ShaderStage::Compute),
+ Attribute::WorkGroupSize(ep.workgroup_size),
+ ],
+ };
+
+ self.write_attributes(&attributes)?;
+ // Add a newline after attribute
+ writeln!(self.out)?;
+
+ let func_ctx = back::FunctionCtx {
+ ty: back::FunctionType::EntryPoint(index as u16),
+ info: info.get_entry_point(index),
+ expressions: &ep.function.expressions,
+ named_expressions: &ep.function.named_expressions,
+ };
+ self.write_function(module, &ep.function, &func_ctx)?;
+
+ if index < module.entry_points.len() - 1 {
+ writeln!(self.out)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write [`ScalarValue`](crate::ScalarValue)
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_scalar_value(&mut self, value: crate::ScalarValue) -> BackendResult {
+ use crate::ScalarValue as Sv;
+
+ match value {
+ Sv::Sint(value) => write!(self.out, "{}", value)?,
+ Sv::Uint(value) => write!(self.out, "{}u", value)?,
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ Sv::Float(value) => write!(self.out, "{:?}", value)?,
+ Sv::Bool(value) => write!(self.out, "{}", value)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write struct name
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_struct_name(&mut self, module: &Module, handle: Handle<crate::Type>) -> BackendResult {
+ if module.types[handle].name.is_none() {
+ if let Some(&(stage, _)) = self.ep_results.iter().find(|&&(_, ty)| ty == handle) {
+ let name = match stage {
+ ShaderStage::Compute => "ComputeOutput",
+ ShaderStage::Fragment => "FragmentOutput",
+ ShaderStage::Vertex => "VertexOutput",
+ };
+
+ write!(self.out, "{}", name)?;
+ return Ok(());
+ }
+ }
+
+ write!(self.out, "{}", self.names[&NameKey::Type(handle)])?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write
+ /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions)
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_function(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ let func_name = match func_ctx.ty {
+ back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
+ back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
+ };
+
+ // Write function name
+ write!(self.out, "fn {}(", func_name)?;
+
+ // Write function arguments
+ for (index, arg) in func.arguments.iter().enumerate() {
+ // Write argument attribute if a binding is present
+ if let Some(ref binding) = arg.binding {
+ self.write_attributes(&map_binding_to_attribute(
+ binding,
+ module.types[arg.ty].inner.scalar_kind(),
+ ))?;
+ }
+ // Write argument name
+ let argument_name = match func_ctx.ty {
+ back::FunctionType::Function(handle) => {
+ &self.names[&NameKey::FunctionArgument(handle, index as u32)]
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]
+ }
+ };
+
+ write!(self.out, "{}: ", argument_name)?;
+ // Write argument type
+ self.write_type(module, arg.ty)?;
+ if index < func.arguments.len() - 1 {
+ // Add a separator between args
+ write!(self.out, ", ")?;
+ }
+ }
+
+ write!(self.out, ")")?;
+
+ // Write function return type
+ if let Some(ref result) = func.result {
+ write!(self.out, " -> ")?;
+ if let Some(ref binding) = result.binding {
+ self.write_attributes(&map_binding_to_attribute(
+ binding,
+ module.types[result.ty].inner.scalar_kind(),
+ ))?;
+ }
+ self.write_type(module, result.ty)?;
+ }
+
+ write!(self.out, " {{")?;
+ writeln!(self.out)?;
+
+ // Write function local variables
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability)
+ write!(self.out, "{}", back::INDENT)?;
+
+ // Write the local name
+ // The leading space is important
+ write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
+
+ // Write the local type
+ self.write_type(module, local.ty)?;
+
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ // Put the equal signal only if there's a initializer
+ // The leading and trailing spaces aren't needed but help with readability
+ write!(self.out, " = ")?;
+
+ // Write the constant
+ // `write_constant` adds no trailing or leading space/newline
+ self.write_constant(module, init)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ if !func.local_variables.is_empty() {
+ writeln!(self.out)?;
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // The indentation should always be 1 when writing the function body
+ self.write_stmt(module, sta, func_ctx, back::Level(1))?;
+ }
+
+ writeln!(self.out, "}}")?;
+
+ self.named_expressions.clear();
+
+ Ok(())
+ }
+
+ /// Helper method to write a attribute
+ fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
+ for attribute in attributes {
+ match *attribute {
+ Attribute::Location(id) => write!(self.out, "@location({}) ", id)?,
+ Attribute::BuiltIn(builtin_attrib) => {
+ if let Some(builtin) = builtin_str(builtin_attrib) {
+ write!(self.out, "@builtin({}) ", builtin)?;
+ } else {
+ log::warn!("Unsupported builtin attribute: {:?}", builtin_attrib);
+ }
+ }
+ Attribute::Stage(shader_stage) => {
+ let stage_str = match shader_stage {
+ ShaderStage::Vertex => "vertex",
+ ShaderStage::Fragment => "fragment",
+ ShaderStage::Compute => "compute",
+ };
+ write!(self.out, "@{} ", stage_str)?;
+ }
+ Attribute::WorkGroupSize(size) => {
+ write!(
+ self.out,
+ "@workgroup_size({}, {}, {}) ",
+ size[0], size[1], size[2]
+ )?;
+ }
+ Attribute::Binding(id) => write!(self.out, "@binding({}) ", id)?,
+ Attribute::Group(id) => write!(self.out, "@group({}) ", id)?,
+ Attribute::Invariant => write!(self.out, "@invariant ")?,
+ Attribute::Interpolate(interpolation, sampling) => {
+ if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
+ write!(
+ self.out,
+ "@interpolate({}, {}) ",
+ interpolation_str(
+ interpolation.unwrap_or(crate::Interpolation::Perspective)
+ ),
+ sampling_str(sampling.unwrap_or(crate::Sampling::Center))
+ )?;
+ } else if interpolation.is_some()
+ && interpolation != Some(crate::Interpolation::Perspective)
+ {
+ write!(
+ self.out,
+ "@interpolate({}) ",
+ interpolation_str(
+ interpolation.unwrap_or(crate::Interpolation::Perspective)
+ )
+ )?;
+ }
+ }
+ };
+ }
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ ) -> BackendResult {
+ write!(self.out, "struct ")?;
+ self.write_struct_name(module, handle)?;
+ write!(self.out, " {{")?;
+ writeln!(self.out)?;
+ for (index, member) in members.iter().enumerate() {
+ // Skip struct member with unsupported built in
+ if let Some(crate::Binding::BuiltIn(built_in)) = member.binding {
+ if builtin_str(built_in).is_none() {
+ log::warn!("Skip member with unsupported builtin {:?}", built_in);
+ continue;
+ }
+ }
+
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+ if let Some(ref binding) = member.binding {
+ self.write_attributes(&map_binding_to_attribute(
+ binding,
+ module.types[member.ty].inner.scalar_kind(),
+ ))?;
+ }
+ // Write struct member name and type
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+ write!(self.out, "{}: ", member_name)?;
+ self.write_type(module, member.ty)?;
+ write!(self.out, ",")?;
+ writeln!(self.out)?;
+ }
+
+ write!(self.out, "}}")?;
+
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &module.types[ty].inner;
+ match *inner {
+ TypeInner::Struct { .. } => self.write_struct_name(module, ty)?,
+ ref other => self.write_value_type(module, other)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ TypeInner::Vector { size, kind, .. } => write!(
+ self.out,
+ "vec{}<{}>",
+ back::vector_size_str(size),
+ scalar_kind_str(kind),
+ )?,
+ TypeInner::Sampler { comparison: false } => {
+ write!(self.out, "sampler")?;
+ }
+ TypeInner::Sampler { comparison: true } => {
+ write!(self.out, "sampler_comparison")?;
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ // More about texture types: https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ use crate::ImageClass as Ic;
+
+ let dim_str = image_dimension_str(dim);
+ let arrayed_str = if arrayed { "_array" } else { "" };
+ let (class_str, multisampled_str, format_str, storage_str) = match class {
+ Ic::Sampled { kind, multi } => (
+ "",
+ if multi { "multisampled_" } else { "" },
+ scalar_kind_str(kind),
+ "",
+ ),
+ Ic::Depth { multi } => {
+ ("depth_", if multi { "multisampled_" } else { "" }, "", "")
+ }
+ Ic::Storage { format, access } => (
+ "storage_",
+ "",
+ storage_format_str(format),
+ if access.contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ ",read_write"
+ } else if access.contains(crate::StorageAccess::LOAD) {
+ ",read"
+ } else {
+ ",write"
+ },
+ ),
+ };
+ write!(
+ self.out,
+ "texture_{}{}{}{}",
+ class_str, multisampled_str, dim_str, arrayed_str
+ )?;
+
+ if !format_str.is_empty() {
+ write!(self.out, "<{}{}>", format_str, storage_str)?;
+ }
+ }
+ TypeInner::Scalar { kind, .. } => {
+ write!(self.out, "{}", scalar_kind_str(kind))?;
+ }
+ TypeInner::Atomic { kind, .. } => {
+ write!(self.out, "atomic<{}>", scalar_kind_str(kind))?;
+ }
+ TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types
+ // array<A, 3> -- Constant array
+ // array<A> -- Dynamic array
+ write!(self.out, "array<")?;
+ match size {
+ crate::ArraySize::Constant(handle) => {
+ self.write_type(module, base)?;
+ write!(self.out, ",")?;
+ self.write_constant(module, handle)?;
+ }
+ crate::ArraySize::Dynamic => {
+ self.write_type(module, base)?;
+ }
+ }
+ write!(self.out, ">")?;
+ }
+ TypeInner::BindingArray { base, size } => {
+ // More info https://github.com/gpuweb/gpuweb/issues/2105
+ write!(self.out, "binding_array<")?;
+ match size {
+ crate::ArraySize::Constant(handle) => {
+ self.write_type(module, base)?;
+ write!(self.out, ",")?;
+ self.write_constant(module, handle)?;
+ }
+ crate::ArraySize::Dynamic => {
+ self.write_type(module, base)?;
+ }
+ }
+ write!(self.out, ">")?;
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width: _,
+ } => {
+ write!(
+ self.out,
+ //TODO: Can matrix be other than f32?
+ "mat{}x{}<f32>",
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ )?;
+ }
+ TypeInner::Pointer { base, space } => {
+ let (address, maybe_access) = address_space_str(space);
+ // Everything but `AddressSpace::Handle` gives us a `address` name, but
+ // Naga IR never produces pointers to handles, so it doesn't matter much
+ // how we write such a type. Just write it as the base type alone.
+ if let Some(space) = address {
+ write!(self.out, "ptr<{}, ", space)?;
+ }
+ self.write_type(module, base)?;
+ if address.is_some() {
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {}", access)?;
+ }
+ write!(self.out, ">")?;
+ }
+ }
+ TypeInner::ValuePointer {
+ size: None,
+ kind,
+ width: _,
+ space,
+ } => {
+ let (address, maybe_access) = address_space_str(space);
+ if let Some(space) = address {
+ write!(self.out, "ptr<{}, {}", space, scalar_kind_str(kind))?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {}", access)?;
+ }
+ write!(self.out, ">")?;
+ } else {
+ return Err(Error::Unimplemented(format!(
+ "ValuePointer to AddressSpace::Handle {:?}",
+ inner
+ )));
+ }
+ }
+ TypeInner::ValuePointer {
+ size: Some(size),
+ kind,
+ width: _,
+ space,
+ } => {
+ let (address, maybe_access) = address_space_str(space);
+ if let Some(space) = address {
+ write!(
+ self.out,
+ "ptr<{}, vec{}<{}>",
+ space,
+ back::vector_size_str(size),
+ scalar_kind_str(kind)
+ )?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {}", access)?;
+ }
+ write!(self.out, ">")?;
+ } else {
+ return Err(Error::Unimplemented(format!(
+ "ValuePointer to AddressSpace::Handle {:?}",
+ inner
+ )));
+ }
+ write!(self.out, ">")?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_value_type {:?}",
+ inner
+ )));
+ }
+ }
+
+ Ok(())
+ }
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ module: &Module,
+ stmt: &crate::Statement,
+ func_ctx: &back::FunctionCtx<'_>,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::{Expression, Statement};
+
+ match *stmt {
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let info = &func_ctx.info[handle];
+ let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else if info.ref_count == 0 {
+ write!(self.out, "{}_ = ", level)?;
+ self.write_expr(module, handle, func_ctx)?;
+ writeln!(self.out, ";")?;
+ continue;
+ } else {
+ let expr = &func_ctx.expressions[handle];
+ let min_ref_count = expr.bake_ref_count();
+ // Forcefully creating baking expressions in some cases to help with readability
+ let required_baking_expr = match *expr {
+ Expression::ImageLoad { .. }
+ | Expression::ImageQuery { .. }
+ | Expression::ImageSample { .. } => true,
+ _ => false,
+ };
+ if min_ref_count <= info.ref_count || required_baking_expr {
+ // If expression contains unsupported builtin we should skip it
+ if let Expression::Load { pointer } = func_ctx.expressions[handle] {
+ if let Expression::AccessIndex { base, index } =
+ func_ctx.expressions[pointer]
+ {
+ if access_to_unsupported_builtin(
+ base,
+ index,
+ module,
+ func_ctx.info,
+ ) {
+ return Ok(());
+ }
+ }
+ }
+
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{}", level)?;
+ self.start_named_expr(module, handle, func_ctx, &name)?;
+ self.write_expr(module, handle, func_ctx)?;
+ self.named_expressions.insert(handle, name);
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{}", level)?;
+ write!(self.out, "if ")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, " {{")?;
+
+ let l2 = level.next();
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{}}} else {{", level)?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ }
+
+ writeln!(self.out, "{}}}", level)?
+ }
+ Statement::Return { value } => {
+ write!(self.out, "{}", level)?;
+ write!(self.out, "return")?;
+ if let Some(return_value) = value {
+ // The leading space is important
+ write!(self.out, " ")?;
+ self.write_expr(module, return_value, func_ctx)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Kill => {
+ write!(self.out, "{}", level)?;
+ writeln!(self.out, "discard;")?
+ }
+ Statement::Store { pointer, value } => {
+ // WGSL does not support all SPIR-V builtins and we should skip it in generated shaders.
+ // We already skip them when we generate struct type.
+ // Now we need to find expression that used struct with ignored builtins
+ if let Expression::AccessIndex { base, index } = func_ctx.expressions[pointer] {
+ if access_to_unsupported_builtin(base, index, module, func_ctx.info) {
+ return Ok(());
+ }
+ }
+ write!(self.out, "{}", level)?;
+
+ let is_atomic = match *func_ctx.info[pointer].ty.inner_with(&module.types) {
+ crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => true,
+ _ => false,
+ },
+ _ => false,
+ };
+ if is_atomic {
+ write!(self.out, "atomicStore(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr_with_indirection(
+ module,
+ pointer,
+ func_ctx,
+ Indirection::Reference,
+ )?;
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ }
+ writeln!(self.out, ";")?
+ }
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{}", level)?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ self.start_named_expr(module, expr, func_ctx, &name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let func_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{}(", func_name)?;
+ for (index, &argument) in arguments.iter().enumerate() {
+ self.write_expr(module, argument, func_ctx)?;
+ // Only write a comma if isn't the last element
+ if index != arguments.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{}", level)?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ let fun_str = fun.to_wgsl();
+ write!(self.out, "atomic{}(", fun_str)?;
+ self.write_expr(module, pointer, func_ctx)?;
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ write!(self.out, ", ")?;
+ self.write_expr(module, cmp, func_ctx)?;
+ }
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ");")?
+ }
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{}", level)?;
+ write!(self.out, "textureStore(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(array_index_expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index_expr, func_ctx)?;
+ }
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Block(ref block) => {
+ write!(self.out, "{}", level)?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, level.next())?
+ }
+ writeln!(self.out, "{}}}", level)?
+ }
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{}", level)?;
+ write!(self.out, "switch ")?;
+ self.write_expr(module, selector, func_ctx)?;
+ writeln!(self.out, " {{")?;
+
+ let type_postfix = match *func_ctx.info[selector].ty.inner_with(&module.types) {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ } => "u",
+ _ => "",
+ };
+
+ let l2 = level.next();
+ if !cases.is_empty() {
+ for case in cases {
+ match case.value {
+ crate::SwitchValue::Integer(value) => {
+ writeln!(self.out, "{}case {}{}: {{", l2, value, type_postfix)?;
+ }
+ crate::SwitchValue::Default => {
+ writeln!(self.out, "{}default: {{", l2)?;
+ }
+ }
+
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2.next())?;
+ }
+
+ if case.fall_through {
+ writeln!(self.out, "{}fallthrough;", l2.next())?;
+ }
+
+ writeln!(self.out, "{}}}", l2)?;
+ }
+ }
+
+ writeln!(self.out, "{}}}", level)?
+ }
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ write!(self.out, "{}", level)?;
+ writeln!(self.out, "loop {{")?;
+
+ let l2 = level.next();
+ for sta in body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // The continuing is optional so we don't need to write it if
+ // it is empty, but the `break if` counts as a continuing statement
+ // so even if `continuing` is empty we must generate it if a
+ // `break if` exists
+ if !continuing.is_empty() || break_if.is_some() {
+ writeln!(self.out, "{}continuing {{", l2)?;
+ for sta in continuing.iter() {
+ self.write_stmt(module, sta, func_ctx, l2.next())?;
+ }
+
+ // The `break if` is always the last
+ // statement of the `continuing` block
+ if let Some(condition) = break_if {
+ // The trailing space is important
+ write!(self.out, "{}break if ", l2.next())?;
+ self.write_expr(module, condition, func_ctx)?;
+ // Close the `break if` statement
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{}}}", l2)?;
+ }
+
+ writeln!(self.out, "{}}}", level)?
+ }
+ Statement::Break => {
+ writeln!(self.out, "{}break;", level)?;
+ }
+ Statement::Continue => {
+ writeln!(self.out, "{}continue;", level)?;
+ }
+ Statement::Barrier(barrier) => {
+ if barrier.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{}storageBarrier();", level)?;
+ }
+
+ if barrier.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{}workgroupBarrier();", level)?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Return the sort of indirection that `expr`'s plain form evaluates to.
+ ///
+ /// An expression's 'plain form' is the most general rendition of that
+ /// expression into WGSL, lacking `&` or `*` operators:
+ ///
+ /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference
+ /// to the local variable's storage.
+ ///
+ /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a
+ /// reference to the global variable's storage. However, globals in the
+ /// `Handle` address space are immutable, and `GlobalVariable` expressions for
+ /// those produce the value directly, not a pointer to it. Such
+ /// `GlobalVariable` expressions are `Ordinary`.
+ ///
+ /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a
+ /// pointer. If they are applied directly to a composite value, they are
+ /// `Ordinary`.
+ ///
+ /// Note that `FunctionArgument` expressions are never `Reference`, even when
+ /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the
+ /// argument's value directly, so any pointer it produces is merely the value
+ /// passed by the caller.
+ fn plain_form_indirection(
+ &self,
+ expr: Handle<crate::Expression>,
+ module: &Module,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> Indirection {
+ use crate::Expression as Ex;
+
+ // Named expressions are `let` expressions, which apply the Load Rule,
+ // so if their type is a Naga pointer, then that must be a WGSL pointer
+ // as well.
+ if self.named_expressions.contains_key(&expr) {
+ return Indirection::Ordinary;
+ }
+
+ match func_ctx.expressions[expr] {
+ Ex::LocalVariable(_) => Indirection::Reference,
+ Ex::GlobalVariable(handle) => {
+ let global = &module.global_variables[handle];
+ match global.space {
+ crate::AddressSpace::Handle => Indirection::Ordinary,
+ _ => Indirection::Reference,
+ }
+ }
+ Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
+ let base_ty = func_ctx.info[base].ty.inner_with(&module.types);
+ match *base_ty {
+ crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
+ Indirection::Reference
+ }
+ _ => Indirection::Ordinary,
+ }
+ }
+ _ => Indirection::Ordinary,
+ }
+ }
+
+ fn start_named_expr(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx,
+ name: &str,
+ ) -> BackendResult {
+ // Write variable name
+ write!(self.out, "let {}", name)?;
+ if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
+ write!(self.out, ": ")?;
+ let ty = &func_ctx.info[handle].ty;
+ // Write variable type
+ match *ty {
+ proc::TypeResolution::Handle(handle) => {
+ self.write_type(module, handle)?;
+ }
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(module, inner)?;
+ }
+ }
+ }
+
+ write!(self.out, " = ")?;
+ Ok(())
+ }
+
+ /// Write the ordinary WGSL form of `expr`.
+ ///
+ /// See `write_expr_with_indirection` for details.
+ fn write_expr(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
+ }
+
+ /// Write `expr` as a WGSL expression with the requested indirection.
+ ///
+ /// In terms of the WGSL grammar, the resulting expression is a
+ /// `singular_expression`. It may be parenthesized. This makes it suitable
+ /// for use as the operand of a unary or binary operator without worrying
+ /// about precedence.
+ ///
+ /// This does not produce newlines or indentation.
+ ///
+ /// The `requested` argument indicates (roughly) whether Naga
+ /// `Pointer`-valued expressions represent WGSL references or pointers. See
+ /// `Indirection` for details.
+ fn write_expr_with_indirection(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ requested: Indirection,
+ ) -> BackendResult {
+ // If the plain form of the expression is not what we need, emit the
+ // operator necessary to correct that.
+ let plain = self.plain_form_indirection(expr, module, func_ctx);
+ match (requested, plain) {
+ (Indirection::Ordinary, Indirection::Reference) => {
+ write!(self.out, "(&")?;
+ self.write_expr_plain_form(module, expr, func_ctx, plain)?;
+ write!(self.out, ")")?;
+ }
+ (Indirection::Reference, Indirection::Ordinary) => {
+ write!(self.out, "(*")?;
+ self.write_expr_plain_form(module, expr, func_ctx, plain)?;
+ write!(self.out, ")")?;
+ }
+ (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
+ }
+
+ Ok(())
+ }
+
+ /// Write the 'plain form' of `expr`.
+ ///
+ /// An expression's 'plain form' is the most general rendition of that
+ /// expression into WGSL, lacking `&` or `*` operators. The plain forms of
+ /// `LocalVariable(x)` and `GlobalVariable(g)` are simply `x` and `g`. Such
+ /// Naga expressions represent both WGSL pointers and references; it's the
+ /// caller's responsibility to distinguish those cases appropriately.
+ fn write_expr_plain_form(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ indirection: Indirection,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{}", name)?;
+ return Ok(());
+ }
+
+ let expression = &func_ctx.expressions[expr];
+
+ // Write the plain WGSL form of a Naga expression.
+ //
+ // The plain form of `LocalVariable` and `GlobalVariable` expressions is
+ // simply the variable name; `*` and `&` operators are never emitted.
+ //
+ // The plain form of `Access` and `AccessIndex` expressions are WGSL
+ // `postfix_expression` forms for member/component access and
+ // subscripting.
+ match *expression {
+ Expression::Constant(constant) => self.write_constant(module, constant)?,
+ Expression::Compose { ty, ref components } => {
+ self.write_type(module, ty)?;
+ write!(self.out, "(")?;
+ // !spv-in specific notes!
+ // WGSL does not support all SPIR-V builtins and we should skip it in generated shaders.
+ // We already skip them when we generate struct type.
+ // Now we need to find components that used struct with ignored builtins.
+
+ // So, why we can't just return the error to a user?
+ // We can, but otherwise, we can't generate WGSL shader from any glslang SPIR-V shaders.
+ // glslang generates gl_PerVertex struct with gl_CullDistance, gl_ClipDistance and gl_PointSize builtin inside by default.
+ // All of them are not supported by WGSL.
+
+ // We need to copy components to another vec because we don't know which of them we should write.
+ let mut components_to_write = Vec::with_capacity(components.len());
+ for component in components {
+ let mut skip_component = false;
+ if let Expression::Load { pointer } = func_ctx.expressions[*component] {
+ if let Expression::AccessIndex { base, index } =
+ func_ctx.expressions[pointer]
+ {
+ if access_to_unsupported_builtin(base, index, module, func_ctx.info) {
+ skip_component = true;
+ }
+ }
+ }
+ if skip_component {
+ continue;
+ } else {
+ components_to_write.push(*component);
+ }
+ }
+
+ // non spv-in specific notes!
+ // Real `Expression::Compose` logic generates here.
+ for (index, component) in components_to_write.iter().enumerate() {
+ self.write_expr(module, *component, func_ctx)?;
+ // Only write a comma if isn't the last element
+ if index != components_to_write.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ")")?
+ }
+ Expression::FunctionArgument(pos) => {
+ let name_key = func_ctx.argument_key(pos);
+ let name = &self.names[&name_key];
+ write!(self.out, "{}", name)?;
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Access { base, index } => {
+ self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
+ write!(self.out, "[")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, "]")?
+ }
+ Expression::AccessIndex { base, index } => {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+
+ self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
+
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, space: _ } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ match *resolved {
+ TypeInner::Vector { .. } => {
+ // Write vector access as a swizzle
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::BindingArray { .. }
+ | TypeInner::ValuePointer { .. } => write!(self.out, "[{}]", index)?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ self.out,
+ ".{}",
+ &self.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => return Err(Error::Custom(format!("Cannot index {:?}", other))),
+ }
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: None,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ use crate::SampleLevel as Sl;
+
+ let suffix_cmp = match depth_ref {
+ Some(_) => "Compare",
+ None => "",
+ };
+ let suffix_level = match level {
+ Sl::Auto => "",
+ Sl::Zero | Sl::Exact(_) => "Level",
+ Sl::Bias(_) => "Bias",
+ Sl::Gradient { .. } => "Grad",
+ };
+
+ write!(self.out, "textureSample{}{}(", suffix_cmp, suffix_level)?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ match level {
+ Sl::Auto => {}
+ Sl::Zero => {
+ // Level 0 is implied for depth comparison
+ if depth_ref.is_none() {
+ write!(self.out, ", 0.0")?;
+ }
+ }
+ Sl::Exact(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Gradient { x, y } => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, x, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, y, func_ctx)?;
+ }
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_constant(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: Some(component),
+ coordinate,
+ array_index,
+ offset,
+ level: _,
+ depth_ref,
+ } => {
+ let suffix_cmp = match depth_ref {
+ Some(_) => "Compare",
+ None => "",
+ };
+
+ write!(self.out, "textureGather{}(", suffix_cmp)?;
+ match *func_ctx.info[image].ty.inner_with(&module.types) {
+ TypeInner::Image {
+ class: crate::ImageClass::Depth { multi: _ },
+ ..
+ } => {}
+ _ => {
+ write!(self.out, "{}, ", component as u8)?;
+ }
+ }
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_constant(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageQuery { image, query } => {
+ use crate::ImageQuery as Iq;
+
+ let texture_function = match query {
+ Iq::Size { .. } => "textureDimensions",
+ Iq::NumLevels => "textureNumLevels",
+ Iq::NumLayers => "textureNumLayers",
+ Iq::NumSamples => "textureNumSamples",
+ };
+
+ write!(self.out, "{}(", texture_function)?;
+ self.write_expr(module, image, func_ctx)?;
+ if let Iq::Size { level: Some(level) } = query {
+ write!(self.out, ", ")?;
+ self.write_expr(module, level, func_ctx)?;
+ };
+ write!(self.out, ")")?;
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ write!(self.out, "textureLoad(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+ if let Some(index) = sample.or(level) {
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Expression::GlobalVariable(handle) => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}", name)?;
+ }
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let inner = func_ctx.info[expr].ty.inner_with(&module.types);
+ match *inner {
+ TypeInner::Matrix { columns, rows, .. } => {
+ write!(
+ self.out,
+ "mat{}x{}<f32>",
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )?;
+ }
+ TypeInner::Vector { size, .. } => {
+ let vector_size_str = back::vector_size_str(size);
+ let scalar_kind_str = scalar_kind_str(kind);
+ if convert.is_some() {
+ write!(self.out, "vec{}<{}>", vector_size_str, scalar_kind_str)?;
+ } else {
+ write!(
+ self.out,
+ "bitcast<vec{}<{}>>",
+ vector_size_str, scalar_kind_str
+ )?;
+ }
+ }
+ TypeInner::Scalar { .. } => {
+ if convert.is_some() {
+ write!(self.out, "{}", scalar_kind_str(kind))?
+ } else {
+ write!(self.out, "bitcast<{}>", scalar_kind_str(kind))?
+ }
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::as {:?}",
+ inner
+ )));
+ }
+ };
+ write!(self.out, "(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Splat { size, value } => {
+ let inner = func_ctx.info[value].ty.inner_with(&module.types);
+ let scalar_kind = match *inner {
+ crate::TypeInner::Scalar { kind, .. } => kind,
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::splat {:?}",
+ inner
+ )));
+ }
+ };
+ let scalar = scalar_kind_str(scalar_kind);
+ let size = back::vector_size_str(size);
+
+ write!(self.out, "vec{}<{}>(", size, scalar)?;
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Load { pointer } => {
+ let is_atomic = match *func_ctx.info[pointer].ty.inner_with(&module.types) {
+ crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => true,
+ _ => false,
+ },
+ _ => false,
+ };
+
+ if is_atomic {
+ write!(self.out, "atomicLoad(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr_with_indirection(
+ module,
+ pointer,
+ func_ctx,
+ Indirection::Reference,
+ )?;
+ }
+ }
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
+ }
+ Expression::ArrayLength(expr) => {
+ write!(self.out, "arrayLength(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ enum Function {
+ Asincosh { is_sin: bool },
+ Atanh,
+ Regular(&'static str),
+ }
+
+ // NOTE: If https://github.com/gpuweb/gpuweb/issues/1622 ever is
+ // accepted, replace this with the builtin functions
+ let function = match fun {
+ Mf::Abs => Function::Regular("abs"),
+ Mf::Min => Function::Regular("min"),
+ Mf::Max => Function::Regular("max"),
+ Mf::Clamp => Function::Regular("clamp"),
+ Mf::Saturate => Function::Regular("saturate"),
+ // trigonometry
+ Mf::Cos => Function::Regular("cos"),
+ Mf::Cosh => Function::Regular("cosh"),
+ Mf::Sin => Function::Regular("sin"),
+ Mf::Sinh => Function::Regular("sinh"),
+ Mf::Tan => Function::Regular("tan"),
+ Mf::Tanh => Function::Regular("tanh"),
+ Mf::Acos => Function::Regular("acos"),
+ Mf::Asin => Function::Regular("asin"),
+ Mf::Atan => Function::Regular("atan"),
+ Mf::Atan2 => Function::Regular("atan2"),
+ Mf::Asinh => Function::Asincosh { is_sin: true },
+ Mf::Acosh => Function::Asincosh { is_sin: false },
+ Mf::Atanh => Function::Atanh,
+ Mf::Radians => Function::Regular("radians"),
+ Mf::Degrees => Function::Regular("degrees"),
+ // decomposition
+ Mf::Ceil => Function::Regular("ceil"),
+ Mf::Floor => Function::Regular("floor"),
+ Mf::Round => Function::Regular("round"),
+ Mf::Fract => Function::Regular("fract"),
+ Mf::Trunc => Function::Regular("trunc"),
+ Mf::Modf => Function::Regular("modf"),
+ Mf::Frexp => Function::Regular("frexp"),
+ Mf::Ldexp => Function::Regular("ldexp"),
+ // exponent
+ Mf::Exp => Function::Regular("exp"),
+ Mf::Exp2 => Function::Regular("exp2"),
+ Mf::Log => Function::Regular("log"),
+ Mf::Log2 => Function::Regular("log2"),
+ Mf::Pow => Function::Regular("pow"),
+ // geometry
+ Mf::Dot => Function::Regular("dot"),
+ Mf::Outer => Function::Regular("outerProduct"),
+ Mf::Cross => Function::Regular("cross"),
+ Mf::Distance => Function::Regular("distance"),
+ Mf::Length => Function::Regular("length"),
+ Mf::Normalize => Function::Regular("normalize"),
+ Mf::FaceForward => Function::Regular("faceForward"),
+ Mf::Reflect => Function::Regular("reflect"),
+ // computational
+ Mf::Sign => Function::Regular("sign"),
+ Mf::Fma => Function::Regular("fma"),
+ Mf::Mix => Function::Regular("mix"),
+ Mf::Step => Function::Regular("step"),
+ Mf::SmoothStep => Function::Regular("smoothstep"),
+ Mf::Sqrt => Function::Regular("sqrt"),
+ Mf::InverseSqrt => Function::Regular("inverseSqrt"),
+ Mf::Transpose => Function::Regular("transpose"),
+ Mf::Determinant => Function::Regular("determinant"),
+ // bits
+ Mf::CountOneBits => Function::Regular("countOneBits"),
+ Mf::ReverseBits => Function::Regular("reverseBits"),
+ Mf::ExtractBits => Function::Regular("extractBits"),
+ Mf::InsertBits => Function::Regular("insertBits"),
+ Mf::FindLsb => Function::Regular("firstTrailingBit"),
+ Mf::FindMsb => Function::Regular("firstLeadingBit"),
+ // data packing
+ Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
+ Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
+ Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"),
+ Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"),
+ Mf::Pack2x16float => Function::Regular("pack2x16float"),
+ // data unpacking
+ Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"),
+ Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"),
+ Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"),
+ Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"),
+ Mf::Unpack2x16float => Function::Regular("unpack2x16float"),
+ _ => {
+ return Err(Error::UnsupportedMathFunction(fun));
+ }
+ };
+
+ match function {
+ Function::Asincosh { is_sin } => {
+ write!(self.out, "log(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " + sqrt(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " * ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ match is_sin {
+ true => write!(self.out, " + 1.0))")?,
+ false => write!(self.out, " - 1.0))")?,
+ }
+ }
+ Function::Atanh => {
+ write!(self.out, "0.5 * log((1.0 + ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") / (1.0 - ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ }
+ Function::Regular(fun_name) => {
+ write!(self.out, "{}(", fun_name)?;
+ self.write_expr(module, arg, func_ctx)?;
+ for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ write!(self.out, ")")?
+ }
+ }
+ }
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(module, vector, func_ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ Expression::Unary { op, expr } => {
+ let unary = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::Not => {
+ match *func_ctx.info[expr].ty.inner_with(&module.types) {
+ TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ }
+ | TypeInner::Vector { .. } => "!",
+ _ => "~",
+ }
+ }
+ };
+
+ write!(self.out, "{}(", unary)?;
+ self.write_expr(module, expr, func_ctx)?;
+
+ write!(self.out, ")")?
+ }
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "select(")?;
+ self.write_expr(module, reject, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, accept, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, condition, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Derivative { axis, expr } => {
+ use crate::DerivativeAxis as Da;
+
+ let op = match axis {
+ Da::X => "dpdx",
+ Da::Y => "dpdy",
+ Da::Width => "fwidth",
+ };
+ write!(self.out, "{}(", op)?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_name = match fun {
+ Rf::IsFinite => "isFinite",
+ Rf::IsNormal => "isNormal",
+ Rf::All => "all",
+ Rf::Any => "any",
+ _ => return Err(Error::UnsupportedRelationalFunction(fun)),
+ };
+ write!(self.out, "{}(", fun_name)?;
+
+ self.write_expr(module, argument, func_ctx)?;
+
+ write!(self.out, ")")?
+ }
+ // Nothing to do here, since call expression already cached
+ Expression::CallResult(_) | Expression::AtomicResult { .. } => {}
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global variables
+ /// # Notes
+ /// Always adds a newline
+ fn write_global(
+ &mut self,
+ module: &Module,
+ global: &crate::GlobalVariable,
+ handle: Handle<crate::GlobalVariable>,
+ ) -> BackendResult {
+ // Write group and binding attributes if present
+ if let Some(ref binding) = global.binding {
+ self.write_attributes(&[
+ Attribute::Group(binding.group),
+ Attribute::Binding(binding.binding),
+ ])?;
+ writeln!(self.out)?;
+ }
+
+ // First write global name and address space if supported
+ write!(self.out, "var")?;
+ let (address, maybe_access) = address_space_str(global.space);
+ if let Some(space) = address {
+ write!(self.out, "<{}", space)?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {}", access)?;
+ }
+ write!(self.out, ">")?;
+ }
+ write!(
+ self.out,
+ " {}: ",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?;
+
+ // Write global type
+ self.write_type(module, global.ty)?;
+
+ // Write initializer
+ if let Some(init) = global.init {
+ write!(self.out, " = ")?;
+ self.write_constant(module, init)?;
+ }
+
+ // End with semicolon
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write constants
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ fn write_constant(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ let constant = &module.constants[handle];
+ match constant.inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => {
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_scalar_value(*value)?;
+ }
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ self.write_type(module, ty)?;
+ write!(self.out, "(")?;
+
+ let members = match module.types[ty].inner {
+ TypeInner::Struct { ref members, .. } => Some(members),
+ _ => None,
+ };
+
+ // Write the comma separated constants
+ for (index, constant) in components.iter().enumerate() {
+ if let Some(&crate::Binding::BuiltIn(built_in)) =
+ members.and_then(|members| members.get(index)?.binding.as_ref())
+ {
+ if builtin_str(built_in).is_none() {
+ log::warn!(
+ "Skip constant for struct member with unsupported builtin {:?}",
+ built_in
+ );
+ continue;
+ }
+ }
+
+ self.write_constant(module, *constant)?;
+ // Only write a comma if isn't the last element
+ if index != components.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ")")?
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_global_constant(
+ &mut self,
+ module: &Module,
+ inner: &crate::ConstantInner,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ match *inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => {
+ let name = &self.names[&NameKey::Constant(handle)];
+ // First write only constant name
+ write!(self.out, "let {}: ", name)?;
+ // Next write constant type and value
+ match *value {
+ crate::ScalarValue::Sint(value) => {
+ write!(self.out, "i32 = {}", value)?;
+ }
+ crate::ScalarValue::Uint(value) => {
+ write!(self.out, "u32 = {}u", value)?;
+ }
+ crate::ScalarValue::Float(value) => {
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ write!(self.out, "f32 = {:?}", value)?;
+ }
+ crate::ScalarValue::Bool(value) => {
+ write!(self.out, "bool = {}", value)?;
+ }
+ };
+ // End with semicolon
+ writeln!(self.out, ";")?;
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ let name = &self.names[&NameKey::Constant(handle)];
+ // First write only constant name
+ write!(self.out, "let {}: ", name)?;
+ // Next write constant type
+ self.write_type(module, ty)?;
+
+ write!(self.out, " = ")?;
+ self.write_type(module, ty)?;
+
+ write!(self.out, "(")?;
+ for (index, constant) in components.iter().enumerate() {
+ self.write_constant(module, *constant)?;
+ // Only write a comma if isn't the last element
+ if index != components.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ");")?;
+ }
+ }
+ // End with extra newline for readability
+ writeln!(self.out)?;
+ Ok(())
+ }
+
+ // See https://github.com/rust-lang/rust-clippy/issues/4979.
+ #[allow(clippy::missing_const_for_fn)]
+ pub fn finish(self) -> W {
+ self.out
+ }
+}
+
+const fn builtin_str(built_in: crate::BuiltIn) -> Option<&'static str> {
+ use crate::BuiltIn as Bi;
+
+ match built_in {
+ Bi::VertexIndex => Some("vertex_index"),
+ Bi::InstanceIndex => Some("instance_index"),
+ Bi::Position { .. } => Some("position"),
+ Bi::FrontFacing => Some("front_facing"),
+ Bi::FragDepth => Some("frag_depth"),
+ Bi::LocalInvocationId => Some("local_invocation_id"),
+ Bi::LocalInvocationIndex => Some("local_invocation_index"),
+ Bi::GlobalInvocationId => Some("global_invocation_id"),
+ Bi::WorkGroupId => Some("workgroup_id"),
+ Bi::WorkGroupSize => Some("workgroup_size"),
+ Bi::NumWorkGroups => Some("num_workgroups"),
+ Bi::SampleIndex => Some("sample_index"),
+ Bi::SampleMask => Some("sample_mask"),
+ Bi::PrimitiveIndex => Some("primitive_index"),
+ Bi::ViewIndex => Some("view_index"),
+ _ => None,
+ }
+}
+
+const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str {
+ use crate::ImageDimension as IDim;
+
+ match dim {
+ IDim::D1 => "1d",
+ IDim::D2 => "2d",
+ IDim::D3 => "3d",
+ IDim::Cube => "cube",
+ }
+}
+
+const fn scalar_kind_str(kind: crate::ScalarKind) -> &'static str {
+ use crate::ScalarKind as Sk;
+
+ match kind {
+ Sk::Float => "f32",
+ Sk::Sint => "i32",
+ Sk::Uint => "u32",
+ Sk::Bool => "bool",
+ }
+}
+
+const fn storage_format_str(format: crate::StorageFormat) -> &'static str {
+ use crate::StorageFormat as Sf;
+
+ match format {
+ Sf::R8Unorm => "r8unorm",
+ Sf::R8Snorm => "r8snorm",
+ Sf::R8Uint => "r8uint",
+ Sf::R8Sint => "r8sint",
+ Sf::R16Uint => "r16uint",
+ Sf::R16Sint => "r16sint",
+ Sf::R16Float => "r16float",
+ Sf::Rg8Unorm => "rg8unorm",
+ Sf::Rg8Snorm => "rg8snorm",
+ Sf::Rg8Uint => "rg8uint",
+ Sf::Rg8Sint => "rg8sint",
+ Sf::R32Uint => "r32uint",
+ Sf::R32Sint => "r32sint",
+ Sf::R32Float => "r32float",
+ Sf::Rg16Uint => "rg16uint",
+ Sf::Rg16Sint => "rg16sint",
+ Sf::Rg16Float => "rg16float",
+ Sf::Rgba8Unorm => "rgba8unorm",
+ Sf::Rgba8Snorm => "rgba8snorm",
+ Sf::Rgba8Uint => "rgba8uint",
+ Sf::Rgba8Sint => "rgba8sint",
+ Sf::Rgb10a2Unorm => "rgb10a2unorm",
+ Sf::Rg11b10Float => "rg11b10float",
+ Sf::Rg32Uint => "rg32uint",
+ Sf::Rg32Sint => "rg32sint",
+ Sf::Rg32Float => "rg32float",
+ Sf::Rgba16Uint => "rgba16uint",
+ Sf::Rgba16Sint => "rgba16sint",
+ Sf::Rgba16Float => "rgba16float",
+ Sf::Rgba32Uint => "rgba32uint",
+ Sf::Rgba32Sint => "rgba32sint",
+ Sf::Rgba32Float => "rgba32float",
+ }
+}
+
+/// Helper function that returns the string corresponding to the WGSL interpolation qualifier
+const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str {
+ use crate::Interpolation as I;
+
+ match interpolation {
+ I::Perspective => "perspective",
+ I::Linear => "linear",
+ I::Flat => "flat",
+ }
+}
+
+/// Return the WGSL auxiliary qualifier for the given sampling value.
+const fn sampling_str(sampling: crate::Sampling) -> &'static str {
+ use crate::Sampling as S;
+
+ match sampling {
+ S::Center => "",
+ S::Centroid => "centroid",
+ S::Sample => "sample",
+ }
+}
+
+const fn address_space_str(
+ space: crate::AddressSpace,
+) -> (Option<&'static str>, Option<&'static str>) {
+ use crate::AddressSpace as As;
+
+ (
+ Some(match space {
+ As::Private => "private",
+ As::Uniform => "uniform",
+ As::Storage { access } => {
+ if access.contains(crate::StorageAccess::STORE) {
+ return (Some("storage"), Some("read_write"));
+ } else {
+ "storage"
+ }
+ }
+ As::PushConstant => "push_constant",
+ As::WorkGroup => "workgroup",
+ As::Handle => return (None, None),
+ As::Function => "function",
+ }),
+ None,
+ )
+}
+
+fn map_binding_to_attribute(
+ binding: &crate::Binding,
+ scalar_kind: Option<crate::ScalarKind>,
+) -> Vec<Attribute> {
+ match *binding {
+ crate::Binding::BuiltIn(built_in) => {
+ if let crate::BuiltIn::Position { invariant: true } = built_in {
+ vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
+ } else {
+ vec![Attribute::BuiltIn(built_in)]
+ }
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => match scalar_kind {
+ Some(crate::ScalarKind::Float) => vec![
+ Attribute::Location(location),
+ Attribute::Interpolate(interpolation, sampling),
+ ],
+ _ => vec![Attribute::Location(location)],
+ },
+ }
+}
+
+/// Helper function that check that expression don't access to structure member with unsupported builtin.
+fn access_to_unsupported_builtin(
+ expr: Handle<crate::Expression>,
+ index: u32,
+ module: &Module,
+ info: &valid::FunctionInfo,
+) -> bool {
+ let base_ty_res = &info[expr].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+ if let TypeInner::Pointer {
+ base: pointer_base_handle,
+ ..
+ } = *resolved
+ {
+ // Let's check that we try to access a struct member with unsupported built-in and skip it.
+ if let TypeInner::Struct { ref members, .. } = module.types[pointer_base_handle].inner {
+ if let Some(crate::Binding::BuiltIn(built_in)) = members[index as usize].binding {
+ if builtin_str(built_in).is_none() {
+ log::warn!("Skip component with unsupported builtin {:?}", built_in);
+ return true;
+ }
+ }
+ }
+ }
+
+ false
+}