summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/back/msl
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/naga/src/back/msl
parentInitial commit. (diff)
downloadfirefox-esr-upstream.tar.xz
firefox-esr-upstream.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/back/msl')
-rw-r--r--third_party/rust/naga/src/back/msl/keywords.rs217
-rw-r--r--third_party/rust/naga/src/back/msl/mod.rs489
-rw-r--r--third_party/rust/naga/src/back/msl/sampler.rs175
-rw-r--r--third_party/rust/naga/src/back/msl/writer.rs4429
4 files changed, 5310 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs
new file mode 100644
index 0000000000..a3a9c52dcc
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/keywords.rs
@@ -0,0 +1,217 @@
+//TODO: find a complete list
+pub const RESERVED: &[&str] = &[
+ // control flow
+ "break",
+ "if",
+ "else",
+ "continue",
+ "goto",
+ "do",
+ "while",
+ "for",
+ "switch",
+ "case",
+ // types and values
+ "void",
+ "unsigned",
+ "signed",
+ "bool",
+ "char",
+ "int",
+ "uint",
+ "long",
+ "float",
+ "double",
+ "char8_t",
+ "wchar_t",
+ "true",
+ "false",
+ "nullptr",
+ "union",
+ "class",
+ "struct",
+ "enum",
+ // other
+ "main",
+ "using",
+ "decltype",
+ "sizeof",
+ "typeof",
+ "typedef",
+ "explicit",
+ "export",
+ "friend",
+ "namespace",
+ "operator",
+ "public",
+ "template",
+ "typename",
+ "typeid",
+ "co_await",
+ "co_return",
+ "co_yield",
+ "module",
+ "import",
+ "ray_data",
+ "vec_step",
+ "visible",
+ "as_type",
+ "this",
+ // qualifiers
+ "mutable",
+ "static",
+ "volatile",
+ "restrict",
+ "const",
+ "non-temporal",
+ "dereferenceable",
+ "invariant",
+ // exceptions
+ "throw",
+ "try",
+ "catch",
+ // operators
+ "const_cast",
+ "dynamic_cast",
+ "reinterpret_cast",
+ "static_cast",
+ "new",
+ "delete",
+ "and",
+ "and_eq",
+ "bitand",
+ "bitor",
+ "compl",
+ "not",
+ "not_eq",
+ "or",
+ "or_eq",
+ "xor",
+ "xor_eq",
+ "compl",
+ // Metal-specific
+ "constant",
+ "device",
+ "threadgroup",
+ "threadgroup_imageblock",
+ "kernel",
+ "compute",
+ "vertex",
+ "fragment",
+ "read_only",
+ "write_only",
+ "read_write",
+ "auto",
+ // Metal reserved types
+ "llong",
+ "ullong",
+ "quad",
+ "complex",
+ "imaginary",
+ // Metal constants
+ "CHAR_BIT",
+ "SCHAR_MAX",
+ "SCHAR_MIN",
+ "UCHAR_MAX",
+ "CHAR_MAX",
+ "CHAR_MIN",
+ "USHRT_MAX",
+ "SHRT_MAX",
+ "SHRT_MIN",
+ "UINT_MAX",
+ "INT_MAX",
+ "INT_MIN",
+ "ULONG_MAX",
+ "LONG_MAX",
+ "LONG_MIN",
+ "ULLONG_MAX",
+ "LLONG_MAX",
+ "LLONG_MIN",
+ "FLT_DIG",
+ "FLT_MANT_DIG",
+ "FLT_MAX_10_EXP",
+ "FLT_MAX_EXP",
+ "FLT_MIN_10_EXP",
+ "FLT_MIN_EXP",
+ "FLT_RADIX",
+ "FLT_MAX",
+ "FLT_MIN",
+ "FLT_EPSILON",
+ "FLT_DECIMAL_DIG",
+ "FP_ILOGB0",
+ "FP_ILOGB0",
+ "FP_ILOGBNAN",
+ "FP_ILOGBNAN",
+ "MAXFLOAT",
+ "HUGE_VALF",
+ "INFINITY",
+ "NAN",
+ "M_E_F",
+ "M_LOG2E_F",
+ "M_LOG10E_F",
+ "M_LN2_F",
+ "M_LN10_F",
+ "M_PI_F",
+ "M_PI_2_F",
+ "M_PI_4_F",
+ "M_1_PI_F",
+ "M_2_PI_F",
+ "M_2_SQRTPI_F",
+ "M_SQRT2_F",
+ "M_SQRT1_2_F",
+ "HALF_DIG",
+ "HALF_MANT_DIG",
+ "HALF_MAX_10_EXP",
+ "HALF_MAX_EXP",
+ "HALF_MIN_10_EXP",
+ "HALF_MIN_EXP",
+ "HALF_RADIX",
+ "HALF_MAX",
+ "HALF_MIN",
+ "HALF_EPSILON",
+ "HALF_DECIMAL_DIG",
+ "MAXHALF",
+ "HUGE_VALH",
+ "M_E_H",
+ "M_LOG2E_H",
+ "M_LOG10E_H",
+ "M_LN2_H",
+ "M_LN10_H",
+ "M_PI_H",
+ "M_PI_2_H",
+ "M_PI_4_H",
+ "M_1_PI_H",
+ "M_2_PI_H",
+ "M_2_SQRTPI_H",
+ "M_SQRT2_H",
+ "M_SQRT1_2_H",
+ "DBL_DIG",
+ "DBL_MANT_DIG",
+ "DBL_MAX_10_EXP",
+ "DBL_MAX_EXP",
+ "DBL_MIN_10_EXP",
+ "DBL_MIN_EXP",
+ "DBL_RADIX",
+ "DBL_MAX",
+ "DBL_MIN",
+ "DBL_EPSILON",
+ "DBL_DECIMAL_DIG",
+ "MAXDOUBLE",
+ "HUGE_VAL",
+ "M_E",
+ "M_LOG2E",
+ "M_LOG10E",
+ "M_LN2",
+ "M_LN10",
+ "M_PI",
+ "M_PI_2",
+ "M_PI_4",
+ "M_1_PI",
+ "M_2_PI",
+ "M_2_SQRTPI",
+ "M_SQRT2",
+ "M_SQRT1_2",
+ // Naga utilities
+ "DefaultConstructible",
+ "clamped_lod_e",
+];
diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs
new file mode 100644
index 0000000000..3174d4b756
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/mod.rs
@@ -0,0 +1,489 @@
+/*!
+Backend for [MSL][msl] (Metal Shading Language).
+
+## Binding model
+
+Metal's bindings are flat per resource. Since there isn't an obvious mapping
+from SPIR-V's descriptor sets, we require a separate mapping provided in the options.
+This mapping may have one or more resource end points for each descriptor set + index
+pair.
+
+## Entry points
+
+Even though MSL and our IR appear to be similar in that the entry points in both can
+accept arguments and return values, the restrictions are different.
+MSL allows the varyings to be either in separate arguments, or inside a single
+`[[stage_in]]` struct. We gather input varyings and form this artificial structure.
+We also add all the (non-Private) globals into the arguments.
+
+At the beginning of the entry point, we assign the local constants and re-compose
+the arguments as they are declared on IR side, so that the rest of the logic can
+pretend that MSL doesn't have all the restrictions it has.
+
+For the result type, if it's a structure, we re-compose it with a temporary value
+holding the result.
+
+[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+*/
+
+use crate::{arena::Handle, proc::index, valid::ModuleInfo};
+use std::fmt::{Error as FmtError, Write};
+
+mod keywords;
+pub mod sampler;
+mod writer;
+
+pub use writer::Writer;
+
+pub type Slot = u8;
+pub type InlineSamplerIndex = u8;
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum BindSamplerTarget {
+ Resource(Slot),
+ Inline(InlineSamplerIndex),
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct BindTarget {
+ pub buffer: Option<Slot>,
+ pub texture: Option<Slot>,
+ pub sampler: Option<BindSamplerTarget>,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+ pub mutable: bool,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
+
+#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct EntryPointResources {
+ pub resources: BindingMap,
+
+ pub push_constant_buffer: Option<Slot>,
+
+ /// The slot of a buffer that contains an array of `u32`,
+ /// one for the size of each bound buffer that contains a runtime array,
+ /// in order of [`crate::GlobalVariable`] declarations.
+ pub sizes_buffer: Option<Slot>,
+}
+
+pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;
+
+enum ResolvedBinding {
+ BuiltIn(crate::BuiltIn),
+ Attribute(u32),
+ Color(u32),
+ User {
+ prefix: &'static str,
+ index: u32,
+ interpolation: Option<ResolvedInterpolation>,
+ },
+ Resource(BindTarget),
+}
+
+#[derive(Copy, Clone)]
+enum ResolvedInterpolation {
+ CenterPerspective,
+ CenterNoPerspective,
+ CentroidPerspective,
+ CentroidNoPerspective,
+ SamplePerspective,
+ SampleNoPerspective,
+ Flat,
+}
+
+// Note: some of these should be removed in favor of proper IR validation.
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error(transparent)]
+ Format(#[from] FmtError),
+ #[error("bind target {0:?} is empty")]
+ UnimplementedBindTarget(BindTarget),
+ #[error("composing of {0:?} is not implemented yet")]
+ UnsupportedCompose(Handle<crate::Type>),
+ #[error("operation {0:?} is not implemented yet")]
+ UnsupportedBinaryOp(crate::BinaryOperator),
+ #[error("standard function '{0}' is not implemented yet")]
+ UnsupportedCall(String),
+ #[error("feature '{0}' is not implemented yet")]
+ FeatureNotImplemented(String),
+ #[error("module is not valid")]
+ Validation,
+ #[error("BuiltIn {0:?} is not supported")]
+ UnsupportedBuiltIn(crate::BuiltIn),
+ #[error("capability {0:?} is not supported")]
+ CapabilityNotSupported(crate::valid::Capabilities),
+ #[error("address space {0:?} is not supported for target MSL version")]
+ UnsupportedAddressSpace(crate::AddressSpace),
+ #[error("attribute '{0}' is not supported for target MSL version")]
+ UnsupportedAttribute(String),
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("global '{0}' doesn't have a binding")]
+ MissingBinding(String),
+ #[error("mapping of {0:?} is missing")]
+ MissingBindTarget(crate::ResourceBinding),
+ #[error("mapping for push constants is missing")]
+ MissingPushConstants,
+ #[error("mapping for sizes buffer is missing")]
+ MissingSizesBuffer,
+}
+
+/// Points in the MSL code where we might emit a pipeline input or output.
+///
+/// Note that, even though vertex shaders' outputs are always fragment
+/// shaders' inputs, we still need to distinguish `VertexOutput` and
+/// `FragmentInput`, since there are certain differences in the way
+/// [`ResolvedBinding`s] are represented on either side.
+///
+/// [`ResolvedBinding`s]: ResolvedBinding
+#[derive(Clone, Copy, Debug)]
+enum LocationMode {
+ /// Input to the vertex shader.
+ VertexInput,
+
+ /// Output from the vertex shader.
+ VertexOutput,
+
+ /// Input to the fragment shader.
+ FragmentInput,
+
+ /// Output from the fragment shader.
+ FragmentOutput,
+
+ /// Compute shader input or output.
+ Uniform,
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// (Major, Minor) target version of the Metal Shading Language.
+ pub lang_version: (u8, u8),
+ /// Map of entry-point resources, indexed by entry point function name, to slots.
+ pub per_entry_point_map: EntryPointResourceMap,
+ /// Samplers to be inlined into the code.
+ pub inline_samplers: Vec<sampler::InlineSampler>,
+ /// Make it possible to link different stages via SPIRV-Cross.
+ pub spirv_cross_compatibility: bool,
+ /// Don't panic on missing bindings, instead generate invalid MSL.
+ pub fake_missing_bindings: bool,
+ /// Bounds checking policies.
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub bounds_check_policies: index::BoundsCheckPolicies,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ lang_version: (2, 0),
+ per_entry_point_map: EntryPointResourceMap::default(),
+ inline_samplers: Vec::new(),
+ spirv_cross_compatibility: false,
+ fake_missing_bindings: true,
+ bounds_check_policies: index::BoundsCheckPolicies::default(),
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+/// A subset of options that are meant to be changed per pipeline.
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// Allow `BuiltIn::PointSize` in the vertex shader.
+ ///
+ /// Metal doesn't like this for non-point primitive topologies.
+ pub allow_point_size: bool,
+}
+
+impl Options {
+ fn resolve_local_binding(
+ &self,
+ binding: &crate::Binding,
+ mode: LocationMode,
+ ) -> Result<ResolvedBinding, Error> {
+ match *binding {
+ crate::Binding::BuiltIn(mut built_in) => {
+ if let crate::BuiltIn::Position { ref mut invariant } = built_in {
+ if *invariant && self.lang_version < (2, 1) {
+ return Err(Error::UnsupportedAttribute("invariant".to_string()));
+ }
+
+ // The 'invariant' attribute may only appear on vertex
+ // shader outputs, not fragment shader inputs.
+ if !matches!(mode, LocationMode::VertexOutput) {
+ *invariant = false;
+ }
+ }
+
+ Ok(ResolvedBinding::BuiltIn(built_in))
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => match mode {
+ LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)),
+ LocationMode::FragmentOutput => Ok(ResolvedBinding::Color(location)),
+ LocationMode::VertexOutput | LocationMode::FragmentInput => {
+ Ok(ResolvedBinding::User {
+ prefix: if self.spirv_cross_compatibility {
+ "locn"
+ } else {
+ "loc"
+ },
+ index: location,
+ interpolation: {
+ // unwrap: The verifier ensures that vertex shader outputs and fragment
+ // shader inputs always have fully specified interpolation, and that
+ // sampling is `None` only for Flat interpolation.
+ let interpolation = interpolation.unwrap();
+ let sampling = sampling.unwrap_or(crate::Sampling::Center);
+ Some(ResolvedInterpolation::from_binding(interpolation, sampling))
+ },
+ })
+ }
+ LocationMode::Uniform => {
+ log::error!(
+ "Unexpected Binding::Location({}) for the Uniform mode",
+ location
+ );
+ Err(Error::Validation)
+ }
+ },
+ }
+ }
+
+ fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
+ self.per_entry_point_map.get(&ep.name)
+ }
+
+ fn get_resource_binding_target(
+ &self,
+ ep: &crate::EntryPoint,
+ res_binding: &crate::ResourceBinding,
+ ) -> Option<&BindTarget> {
+ self.get_entry_point_resources(ep)
+ .and_then(|res| res.resources.get(res_binding))
+ }
+
+ fn resolve_resource_binding(
+ &self,
+ ep: &crate::EntryPoint,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let target = self.get_resource_binding_target(ep, res_binding);
+ match target {
+ Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingBindTarget(res_binding.clone())),
+ }
+ }
+
+ fn resolve_push_constants(
+ &self,
+ ep: &crate::EntryPoint,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = self
+ .get_entry_point_resources(ep)
+ .and_then(|res| res.push_constant_buffer);
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ ..Default::default()
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingPushConstants),
+ }
+ }
+
+ fn resolve_sizes_buffer(
+ &self,
+ ep: &crate::EntryPoint,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = self
+ .get_entry_point_resources(ep)
+ .and_then(|res| res.sizes_buffer);
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ ..Default::default()
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingSizesBuffer),
+ }
+ }
+}
+
+impl ResolvedBinding {
+ fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> {
+ match *self {
+ Self::Resource(BindTarget {
+ sampler: Some(BindSamplerTarget::Inline(index)),
+ ..
+ }) => Some(&options.inline_samplers[index as usize]),
+ _ => None,
+ }
+ }
+
+ const fn as_bind_target(&self) -> Option<&BindTarget> {
+ match *self {
+ Self::Resource(ref target) => Some(target),
+ _ => None,
+ }
+ }
+
+ fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
+ write!(out, " [[")?;
+ match *self {
+ Self::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let name = match built_in {
+ Bi::Position { invariant: false } => "position",
+ Bi::Position { invariant: true } => "position, invariant",
+ // vertex
+ Bi::BaseInstance => "base_instance",
+ Bi::BaseVertex => "base_vertex",
+ Bi::ClipDistance => "clip_distance",
+ Bi::InstanceIndex => "instance_id",
+ Bi::PointSize => "point_size",
+ Bi::VertexIndex => "vertex_id",
+ // fragment
+ Bi::FragDepth => "depth(any)",
+ Bi::PointCoord => "point_coord",
+ Bi::FrontFacing => "front_facing",
+ Bi::PrimitiveIndex => "primitive_id",
+ Bi::SampleIndex => "sample_id",
+ Bi::SampleMask => "sample_mask",
+ // compute
+ Bi::GlobalInvocationId => "thread_position_in_grid",
+ Bi::LocalInvocationId => "thread_position_in_threadgroup",
+ Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
+ Bi::WorkGroupId => "threadgroup_position_in_grid",
+ Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
+ Bi::NumWorkGroups => "threadgroups_per_grid",
+ Bi::CullDistance | Bi::ViewIndex => {
+ return Err(Error::UnsupportedBuiltIn(built_in))
+ }
+ };
+ write!(out, "{name}")?;
+ }
+ Self::Attribute(index) => write!(out, "attribute({index})")?,
+ Self::Color(index) => write!(out, "color({index})")?,
+ Self::User {
+ prefix,
+ index,
+ interpolation,
+ } => {
+ write!(out, "user({prefix}{index})")?;
+ if let Some(interpolation) = interpolation {
+ write!(out, ", ")?;
+ interpolation.try_fmt(out)?;
+ }
+ }
+ Self::Resource(ref target) => {
+ if let Some(id) = target.buffer {
+ write!(out, "buffer({id})")?;
+ } else if let Some(id) = target.texture {
+ write!(out, "texture({id})")?;
+ } else if let Some(BindSamplerTarget::Resource(id)) = target.sampler {
+ write!(out, "sampler({id})")?;
+ } else {
+ return Err(Error::UnimplementedBindTarget(target.clone()));
+ }
+ }
+ }
+ write!(out, "]]")?;
+ Ok(())
+ }
+}
+
+impl ResolvedInterpolation {
+ const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self {
+ use crate::Interpolation as I;
+ use crate::Sampling as S;
+
+ match (interpolation, sampling) {
+ (I::Perspective, S::Center) => Self::CenterPerspective,
+ (I::Perspective, S::Centroid) => Self::CentroidPerspective,
+ (I::Perspective, S::Sample) => Self::SamplePerspective,
+ (I::Linear, S::Center) => Self::CenterNoPerspective,
+ (I::Linear, S::Centroid) => Self::CentroidNoPerspective,
+ (I::Linear, S::Sample) => Self::SampleNoPerspective,
+ (I::Flat, _) => Self::Flat,
+ }
+ }
+
+ fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> {
+ let identifier = match self {
+ Self::CenterPerspective => "center_perspective",
+ Self::CenterNoPerspective => "center_no_perspective",
+ Self::CentroidPerspective => "centroid_perspective",
+ Self::CentroidNoPerspective => "centroid_no_perspective",
+ Self::SamplePerspective => "sample_perspective",
+ Self::SampleNoPerspective => "sample_no_perspective",
+ Self::Flat => "flat",
+ };
+ out.write_str(identifier)?;
+ Ok(())
+ }
+}
+
+/// Information about a translated module that is required
+/// for the use of the result.
+pub struct TranslationInfo {
+ /// Mapping of the entry point names. Each item in the array
+ /// corresponds to an entry point index.
+ ///
+ ///Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+pub fn write_string(
+ module: &crate::Module,
+ info: &ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+) -> Result<(String, TranslationInfo), Error> {
+ let mut w = writer::Writer::new(String::new());
+ let info = w.write(module, info, options, pipeline_options)?;
+ Ok((w.finish(), info))
+}
+
+#[test]
+fn test_error_size() {
+ use std::mem::size_of;
+ assert_eq!(size_of::<Error>(), 32);
+}
diff --git a/third_party/rust/naga/src/back/msl/sampler.rs b/third_party/rust/naga/src/back/msl/sampler.rs
new file mode 100644
index 0000000000..3b95fa3781
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/sampler.rs
@@ -0,0 +1,175 @@
+#[cfg(feature = "deserialize")]
+use serde::Deserialize;
+#[cfg(feature = "serialize")]
+use serde::Serialize;
+use std::{num::NonZeroU32, ops::Range};
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Coord {
+ Normalized,
+ Pixel,
+}
+
+impl Default for Coord {
+ fn default() -> Self {
+ Self::Normalized
+ }
+}
+
+impl Coord {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Normalized => "normalized",
+ Self::Pixel => "pixel",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Address {
+ Repeat,
+ MirroredRepeat,
+ ClampToEdge,
+ ClampToZero,
+ ClampToBorder,
+}
+
+impl Default for Address {
+ fn default() -> Self {
+ Self::ClampToEdge
+ }
+}
+
+impl Address {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Repeat => "repeat",
+ Self::MirroredRepeat => "mirrored_repeat",
+ Self::ClampToEdge => "clamp_to_edge",
+ Self::ClampToZero => "clamp_to_zero",
+ Self::ClampToBorder => "clamp_to_border",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum BorderColor {
+ TransparentBlack,
+ OpaqueBlack,
+ OpaqueWhite,
+}
+
+impl Default for BorderColor {
+ fn default() -> Self {
+ Self::TransparentBlack
+ }
+}
+
+impl BorderColor {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::TransparentBlack => "transparent_black",
+ Self::OpaqueBlack => "opaque_black",
+ Self::OpaqueWhite => "opaque_white",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Filter {
+ Nearest,
+ Linear,
+}
+
+impl Filter {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Nearest => "nearest",
+ Self::Linear => "linear",
+ }
+ }
+}
+
+impl Default for Filter {
+ fn default() -> Self {
+ Self::Nearest
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum CompareFunc {
+ Never,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+ Equal,
+ NotEqual,
+ Always,
+}
+
+impl Default for CompareFunc {
+ fn default() -> Self {
+ Self::Never
+ }
+}
+
+impl CompareFunc {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Never => "never",
+ Self::Less => "less",
+ Self::LessEqual => "less_equal",
+ Self::Greater => "greater",
+ Self::GreaterEqual => "greater_equal",
+ Self::Equal => "equal",
+ Self::NotEqual => "not_equal",
+ Self::Always => "always",
+ }
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct InlineSampler {
+ pub coord: Coord,
+ pub address: [Address; 3],
+ pub border_color: BorderColor,
+ pub mag_filter: Filter,
+ pub min_filter: Filter,
+ pub mip_filter: Option<Filter>,
+ pub lod_clamp: Option<Range<f32>>,
+ pub max_anisotropy: Option<NonZeroU32>,
+ pub compare_func: CompareFunc,
+}
+
+impl Eq for InlineSampler {}
+
+#[allow(clippy::derive_hash_xor_eq)]
+impl std::hash::Hash for InlineSampler {
+ fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
+ self.coord.hash(hasher);
+ self.address.hash(hasher);
+ self.border_color.hash(hasher);
+ self.mag_filter.hash(hasher);
+ self.min_filter.hash(hasher);
+ self.mip_filter.hash(hasher);
+ self.lod_clamp
+ .as_ref()
+ .map(|range| (range.start.to_bits(), range.end.to_bits()))
+ .hash(hasher);
+ self.max_anisotropy.hash(hasher);
+ self.compare_func.hash(hasher);
+ }
+}
diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs
new file mode 100644
index 0000000000..88600a8f11
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/writer.rs
@@ -0,0 +1,4429 @@
+use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
+use crate::{
+ arena::Handle,
+ back,
+ proc::index,
+ proc::{self, NameKey, TypeResolution},
+ valid, FastHashMap, FastHashSet,
+};
+use bit_set::BitSet;
+use std::{
+ fmt::{Display, Error as FmtError, Formatter, Write},
+ iter,
+};
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+const NAMESPACE: &str = "metal";
+// The name of the array member of the Metal struct types we generate to
+// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
+// for details.
+const WRAPPED_ARRAY_FIELD: &str = "inner";
+// This is a hack: we need to pass a pointer to an atomic,
+// but generally the backend isn't putting "&" in front of every pointer.
+// Some more general handling of pointers is needed to be implemented here.
+const ATOMIC_REFERENCE: &str = "&";
+
+const RT_NAMESPACE: &str = "metal::raytracing";
+const RAY_QUERY_TYPE: &str = "_RayQuery";
+const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
+const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
+const RAY_QUERY_FIELD_READY: &str = "ready";
+const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
+
+/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
+///
+/// The `sizes` slice determines whether this function writes a
+/// scalar, vector, or matrix type:
+///
+/// - An empty slice produces a scalar type.
+/// - A one-element slice produces a vector type.
+/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
+fn put_numeric_type(
+ out: &mut impl Write,
+ kind: crate::ScalarKind,
+ sizes: &[crate::VectorSize],
+) -> Result<(), FmtError> {
+ match (kind, sizes) {
+ (kind, &[]) => {
+ write!(out, "{}", kind.to_msl_name())
+ }
+ (kind, &[rows]) => {
+ write!(
+ out,
+ "{}::{}{}",
+ NAMESPACE,
+ kind.to_msl_name(),
+ back::vector_size_str(rows)
+ )
+ }
+ (kind, &[rows, columns]) => {
+ write!(
+ out,
+ "{}::{}{}x{}",
+ NAMESPACE,
+ kind.to_msl_name(),
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )
+ }
+ (_, _) => Ok(()), // not meaningful
+ }
+}
+
+/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
+const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
+
+struct TypeContext<'a> {
+ handle: Handle<crate::Type>,
+ module: &'a crate::Module,
+ names: &'a FastHashMap<NameKey, String>,
+ access: crate::StorageAccess,
+ binding: Option<&'a super::ResolvedBinding>,
+ first_time: bool,
+}
+
+impl<'a> Display for TypeContext<'a> {
+ fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
+ let ty = &self.module.types[self.handle];
+ if ty.needs_alias() && !self.first_time {
+ let name = &self.names[&NameKey::Type(self.handle)];
+ return write!(out, "{name}");
+ }
+
+ match ty.inner {
+ crate::TypeInner::Scalar { kind, .. } => put_numeric_type(out, kind, &[]),
+ crate::TypeInner::Atomic { kind, .. } => {
+ write!(out, "{}::atomic_{}", NAMESPACE, kind.to_msl_name())
+ }
+ crate::TypeInner::Vector { size, kind, .. } => put_numeric_type(out, kind, &[size]),
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(out, crate::ScalarKind::Float, &[rows, columns])
+ }
+ crate::TypeInner::Pointer { base, space } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ let space_name = match space.to_msl_name() {
+ Some(name) => name,
+ None => return Ok(()),
+ };
+ write!(out, "{space_name} {sub}&")
+ }
+ crate::TypeInner::ValuePointer {
+ size,
+ kind,
+ width: _,
+ space,
+ } => {
+ match space.to_msl_name() {
+ Some(name) => write!(out, "{name} ")?,
+ None => return Ok(()),
+ };
+ match size {
+ Some(rows) => put_numeric_type(out, kind, &[rows])?,
+ None => put_numeric_type(out, kind, &[])?,
+ };
+
+ write!(out, "&")
+ }
+ crate::TypeInner::Array { base, .. } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ // Array lengths go at the end of the type definition,
+ // so just print the element type here.
+ write!(out, "{sub}")
+ }
+ crate::TypeInner::Struct { .. } => unreachable!(),
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_str = match dim {
+ crate::ImageDimension::D1 => "1d",
+ crate::ImageDimension::D2 => "2d",
+ crate::ImageDimension::D3 => "3d",
+ crate::ImageDimension::Cube => "cube",
+ };
+ let (texture_str, msaa_str, kind, access) = match class {
+ crate::ImageClass::Sampled { kind, multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("texture", msaa_str, kind, access)
+ }
+ crate::ImageClass::Depth { multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("depth", msaa_str, crate::ScalarKind::Float, access)
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let access = if self
+ .access
+ .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ "read_write"
+ } else if self.access.contains(crate::StorageAccess::STORE) {
+ "write"
+ } else if self.access.contains(crate::StorageAccess::LOAD) {
+ "read"
+ } else {
+ log::warn!(
+ "Storage access for {:?} (name '{}'): {:?}",
+ self.handle,
+ ty.name.as_deref().unwrap_or_default(),
+ self.access
+ );
+ unreachable!("module is not valid");
+ };
+ ("texture", "", format.into(), access)
+ }
+ };
+ let base_name = kind.to_msl_name();
+ let array_str = if arrayed { "_array" } else { "" };
+ write!(
+ out,
+ "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
+ )
+ }
+ crate::TypeInner::Sampler { comparison: _ } => {
+ write!(out, "{NAMESPACE}::sampler")
+ }
+ crate::TypeInner::AccelerationStructure => {
+ write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
+ }
+ crate::TypeInner::RayQuery => {
+ write!(out, "{RAY_QUERY_TYPE}")
+ }
+ crate::TypeInner::BindingArray { base, size } => {
+ let base_tyname = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+
+ if let Some(&super::ResolvedBinding::Resource(super::BindTarget {
+ binding_array_size: Some(override_size),
+ ..
+ })) = self.binding
+ {
+ write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>")
+ } else if let crate::ArraySize::Constant(size) = size {
+ let constant_ctx = ConstantContext {
+ handle: size,
+ arena: &self.module.constants,
+ names: self.names,
+ first_time: false,
+ };
+ write!(out, "{NAMESPACE}::array<{base_tyname}, {constant_ctx}>")
+ } else {
+ unreachable!("metal requires all arrays be constant sized");
+ }
+ }
+ }
+ }
+}
+
+struct TypedGlobalVariable<'a> {
+ module: &'a crate::Module,
+ names: &'a FastHashMap<NameKey, String>,
+ handle: Handle<crate::GlobalVariable>,
+ usage: valid::GlobalUse,
+ binding: Option<&'a super::ResolvedBinding>,
+ reference: bool,
+}
+
+impl<'a> TypedGlobalVariable<'a> {
+ fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
+ let var = &self.module.global_variables[self.handle];
+ let name = &self.names[&NameKey::GlobalVariable(self.handle)];
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => match self.module.types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ crate::TypeInner::BindingArray { base, .. } => {
+ match self.module.types[base].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ _ => crate::StorageAccess::default(),
+ }
+ }
+ _ => crate::StorageAccess::default(),
+ },
+ };
+ let ty_name = TypeContext {
+ handle: var.ty,
+ module: self.module,
+ names: self.names,
+ access: storage_access,
+ binding: self.binding,
+ first_time: false,
+ };
+
+ let (space, access, reference) = match var.space.to_msl_name() {
+ Some(space) if self.reference => {
+ let access = if var.space.needs_access_qualifier()
+ && !self.usage.contains(valid::GlobalUse::WRITE)
+ {
+ "const"
+ } else {
+ ""
+ };
+ (space, access, "&")
+ }
+ _ => ("", "", ""),
+ };
+
+ Ok(write!(
+ out,
+ "{}{}{}{}{}{} {}",
+ space,
+ if space.is_empty() { "" } else { " " },
+ ty_name,
+ if access.is_empty() { "" } else { " " },
+ access,
+ reference,
+ name,
+ )?)
+ }
+}
+
+struct ConstantContext<'a> {
+ handle: Handle<crate::Constant>,
+ arena: &'a crate::Arena<crate::Constant>,
+ names: &'a FastHashMap<NameKey, String>,
+ first_time: bool,
+}
+
+impl<'a> Display for ConstantContext<'a> {
+ fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
+ let con = &self.arena[self.handle];
+ if con.needs_alias() && !self.first_time {
+ let name = &self.names[&NameKey::Constant(self.handle)];
+ return write!(out, "{name}");
+ }
+
+ match con.inner {
+ crate::ConstantInner::Scalar { value, width: _ } => match value {
+ crate::ScalarValue::Sint(value) => {
+ write!(out, "{value}")
+ }
+ crate::ScalarValue::Uint(value) => {
+ write!(out, "{value}u")
+ }
+ crate::ScalarValue::Float(value) => {
+ if value.is_infinite() {
+ let sign = if value.is_sign_negative() { "-" } else { "" };
+ write!(out, "{sign}INFINITY")
+ } else if value.is_nan() {
+ write!(out, "NAN")
+ } else {
+ let suffix = if value.fract() == 0.0 { ".0" } else { "" };
+
+ write!(out, "{value}{suffix}")
+ }
+ }
+ crate::ScalarValue::Bool(value) => {
+ write!(out, "{value}")
+ }
+ },
+ crate::ConstantInner::Composite { .. } => unreachable!("should be aliased"),
+ }
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ names: FastHashMap<NameKey, String>,
+ named_expressions: crate::NamedExpressions,
+ /// Set of expressions that need to be baked to avoid unnecessary repetition in output
+ need_bake_expressions: back::NeedBakeExpressions,
+ namer: proc::Namer,
+ #[cfg(test)]
+ put_expression_stack_pointers: FastHashSet<*const ()>,
+ #[cfg(test)]
+ put_block_stack_pointers: FastHashSet<*const ()>,
+ /// Set of (struct type, struct field index) denoting which fields require
+ /// padding inserted **before** them (i.e. between fields at index - 1 and index)
+ struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
+}
+
+impl crate::ScalarKind {
+ const fn to_msl_name(self) -> &'static str {
+ match self {
+ Self::Float => "float",
+ Self::Sint => "int",
+ Self::Uint => "uint",
+ Self::Bool => "bool",
+ }
+ }
+}
+
+const fn separate(need_separator: bool) -> &'static str {
+ if need_separator {
+ ","
+ } else {
+ ""
+ }
+}
+
+fn should_pack_struct_member(
+ members: &[crate::StructMember],
+ span: u32,
+ index: usize,
+ module: &crate::Module,
+) -> Option<crate::ScalarKind> {
+ let member = &members[index];
+ //Note: this is imperfect - the same structure can be used for host-shared
+ // things, where packed float would matter.
+ if member.binding.is_some() {
+ return None;
+ }
+
+ let ty_inner = &module.types[member.ty].inner;
+ let last_offset = member.offset + ty_inner.size(&module.constants);
+ let next_offset = match members.get(index + 1) {
+ Some(next) => next.offset,
+ None => span,
+ };
+ let is_tight = next_offset == last_offset;
+
+ match *ty_inner {
+ crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ width: 4,
+ kind,
+ } if member.offset & 0xF != 0 || is_tight => Some(kind),
+ _ => None,
+ }
+}
+
+fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
+ match arena[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ if let Some(member) = members.last() {
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = arena[member.ty].inner
+ {
+ return true;
+ }
+ }
+ false
+ }
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => true,
+ _ => false,
+ }
+}
+
+impl crate::AddressSpace {
+ /// Returns true if global variables in this address space are
+ /// passed in function arguments. These arguments need to be
+ /// passed through any functions called from the entry point.
+ const fn needs_pass_through(&self) -> bool {
+ match *self {
+ Self::Uniform
+ | Self::Storage { .. }
+ | Self::Private
+ | Self::WorkGroup
+ | Self::PushConstant
+ | Self::Handle => true,
+ Self::Function => false,
+ }
+ }
+
+ /// Returns true if the address space may need a "const" qualifier.
+ const fn needs_access_qualifier(&self) -> bool {
+ match *self {
+ //Note: we are ignoring the storage access here, and instead
+ // rely on the actual use of a global by functions. This means we
+ // may end up with "const" even if the binding is read-write,
+ // and that should be OK.
+ Self::Storage { .. } => true,
+ // These should always be read-write.
+ Self::Private | Self::WorkGroup => false,
+ // These translate to `constant` address space, no need for qualifiers.
+ Self::Uniform | Self::PushConstant => false,
+ // Not applicable.
+ Self::Handle | Self::Function => false,
+ }
+ }
+
+ const fn to_msl_name(self) -> Option<&'static str> {
+ match self {
+ Self::Handle => None,
+ Self::Uniform | Self::PushConstant => Some("constant"),
+ Self::Storage { .. } => Some("device"),
+ Self::Private | Self::Function => Some("thread"),
+ Self::WorkGroup => Some("threadgroup"),
+ }
+ }
+}
+
+impl crate::Type {
+ // Returns `true` if we need to emit an alias for this type.
+ const fn needs_alias(&self) -> bool {
+ use crate::TypeInner as Ti;
+
+ match self.inner {
+ // value types are concise enough, we only alias them if they are named
+ Ti::Scalar { .. }
+ | Ti::Vector { .. }
+ | Ti::Matrix { .. }
+ | Ti::Atomic { .. }
+ | Ti::Pointer { .. }
+ | Ti::ValuePointer { .. } => self.name.is_some(),
+ // composite types are better to be aliased, regardless of the name
+ Ti::Struct { .. } | Ti::Array { .. } => true,
+ // handle types may be different, depending on the global var access, so we always inline them
+ Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery
+ | Ti::BindingArray { .. } => false,
+ }
+ }
+}
+
+impl crate::Constant {
+ // Returns `true` if we need to emit an alias for this constant.
+ const fn needs_alias(&self) -> bool {
+ match self.inner {
+ crate::ConstantInner::Scalar { .. } => self.name.is_some(),
+ crate::ConstantInner::Composite { .. } => true,
+ }
+ }
+}
+
+enum FunctionOrigin {
+ Handle(Handle<crate::Function>),
+ EntryPoint(proc::EntryPointIndex),
+}
+
+/// A level of detail argument.
+///
+/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
+/// save the clamped level of detail in a temporary variable whose name is based
+/// on the handle of the `ImageLoad` expression. But for other policies, we just
+/// use the expression directly.
+///
+/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+#[derive(Clone, Copy)]
+enum LevelOfDetail {
+ Direct(Handle<crate::Expression>),
+ Restricted(Handle<crate::Expression>),
+}
+
+/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
+///
+/// When this is used in code paths unconcerned with the `Restrict` bounds check
+/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
+/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
+/// be too awkward. If that changes, we can revisit.
+///
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+/// [`ImageStore`]: crate::Statement::ImageStore
+struct TexelAddress {
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ level: Option<LevelOfDetail>,
+}
+
+struct ExpressionContext<'a> {
+ function: &'a crate::Function,
+ origin: FunctionOrigin,
+ info: &'a valid::FunctionInfo,
+ module: &'a crate::Module,
+ pipeline_options: &'a PipelineOptions,
+ policies: index::BoundsCheckPolicies,
+
+ /// A bitset containing the `Expression` handle indexes of expressions used
+ /// as indices in `ReadZeroSkipWrite`-policy accesses. These may need to be
+ /// cached in temporary variables. See `index::find_checked_indexes` for
+ /// details.
+ guarded_indices: BitSet,
+}
+
+impl<'a> ExpressionContext<'a> {
+ fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
+ self.info[handle].ty.inner_with(&self.module.types)
+ }
+
+ /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
+ ///
+ /// Only mipmapped images need to specify a level of detail. Since 1D
+ /// textures cannot have mipmaps, MSL requires that the level argument to
+ /// texture1d queries and accesses must be a constexpr 0. It's easiest
+ /// just to omit the level entirely for 1D textures.
+ fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
+ let image_ty = self.resolve_type(image);
+ if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
+ class.is_mipmapped() && dim != crate::ImageDimension::D1
+ } else {
+ false
+ }
+ }
+
+ fn choose_bounds_check_policy(
+ &self,
+ pointer: Handle<crate::Expression>,
+ ) -> index::BoundsCheckPolicy {
+ self.policies
+ .choose_policy(pointer, &self.module.types, self.info)
+ }
+
+ fn access_needs_check(
+ &self,
+ base: Handle<crate::Expression>,
+ index: index::GuardedIndex,
+ ) -> Option<index::IndexableLength> {
+ index::access_needs_check(base, index, self.module, self.function, self.info)
+ }
+
+ fn get_packed_vec_kind(
+ &self,
+ expr_handle: Handle<crate::Expression>,
+ ) -> Option<crate::ScalarKind> {
+ match self.function.expressions[expr_handle] {
+ crate::Expression::AccessIndex { base, index } => {
+ let ty = match *self.resolve_type(base) {
+ crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
+ ref ty => ty,
+ };
+ match *ty {
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => should_pack_struct_member(members, span, index as usize, self.module),
+ _ => None,
+ }
+ }
+ _ => None,
+ }
+ }
+}
+
+struct StatementContext<'a> {
+ expression: ExpressionContext<'a>,
+ mod_info: &'a valid::ModuleInfo,
+ result_struct: Option<&'a str>,
+}
+
+impl<W: Write> Writer<W> {
+ /// Creates a new `Writer` instance.
+ pub fn new(out: W) -> Self {
+ Writer {
+ out,
+ names: FastHashMap::default(),
+ named_expressions: Default::default(),
+ need_bake_expressions: Default::default(),
+ namer: proc::Namer::default(),
+ #[cfg(test)]
+ put_expression_stack_pointers: Default::default(),
+ #[cfg(test)]
+ put_block_stack_pointers: Default::default(),
+ struct_member_pads: FastHashSet::default(),
+ }
+ }
+
+ /// Finishes writing and returns the output.
+ // See https://github.com/rust-lang/rust-clippy/issues/4979.
+ #[allow(clippy::missing_const_for_fn)]
+ pub fn finish(self) -> W {
+ self.out
+ }
+
+ fn put_call_parameters(
+ &mut self,
+ parameters: impl Iterator<Item = Handle<crate::Expression>>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "(")?;
+ for (i, handle) in parameters.enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_expression(handle, context, true)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_level_of_detail(
+ &mut self,
+ level: LevelOfDetail,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match level {
+ LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
+ LevelOfDetail::Restricted(load) => {
+ write!(self.out, "{}{}", CLAMPED_LOD_LOAD_PREFIX, load.index())?
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ query: &str,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_{query}(")?;
+ if let Some(level) = level {
+ self.put_level_of_detail(level, context)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_size_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ kind: crate::ScalarKind,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ //Note: MSL only has separate width/height/depth queries,
+ // so compose the result of them.
+ let dim = match *context.resolve_type(image) {
+ crate::TypeInner::Image { dim, .. } => dim,
+ ref other => unreachable!("Unexpected type {:?}", other),
+ };
+ let coordinate_type = kind.to_msl_name();
+ match dim {
+ crate::ImageDimension::D1 => {
+ // Since 1D textures never have mipmaps, MSL requires that the
+ // `level` argument be a constexpr 0. It's simplest for us just
+ // to pass `None` and omit the level entirely.
+ if kind == crate::ScalarKind::Uint {
+ // No need to construct a vector. No cast needed.
+ self.put_image_query(image, "width", None, context)?;
+ } else {
+ // There's no definition for `int` in the `metal` namespace.
+ write!(self.out, "int(")?;
+ self.put_image_query(image, "width", None, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ crate::ImageDimension::D2 => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::D3 => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "depth", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::Cube => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_cast_to_uint_scalar_or_vector(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // coordinates in IR are int, but Metal expects uint
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar { .. } => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Uint, &[])?
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Uint, &[size])?
+ }
+ _ => return Err(Error::Validation),
+ };
+
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_sample_level(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: crate::SampleLevel,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let has_levels = context.image_needs_lod(image);
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {
+ //TODO: do we support Zero on `Sampled` image classes?
+ }
+ _ if !has_levels => {
+ log::warn!("1D image can't be sampled with level {:?}", level);
+ }
+ crate::SampleLevel::Exact(h) => {
+ write!(self.out, ", {NAMESPACE}::level(")?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Bias(h) => {
+ write!(self.out, ", {NAMESPACE}::bias(")?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ write!(self.out, ", {NAMESPACE}::gradient2d(")?;
+ self.put_expression(x, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(y, context, true)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_coordinate_limits(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ write!(self.out, " - 1")?;
+ Ok(())
+ }
+
+ /// General function for writing restricted image indexes.
+ ///
+ /// This is used to produce restricted mip levels, array indices, and sample
+ /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
+ /// [`Restrict`] bounds check policy.
+ ///
+ /// This function writes an expression of the form:
+ ///
+ /// ```ignore
+ ///
+ /// metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
+ ///
+ /// ```
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageStore`]: crate::Statement::ImageStore
+ /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
+ fn put_restricted_scalar_image_index(
+ &mut self,
+ image: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ limit_method: &str,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "{NAMESPACE}::min(uint(")?;
+ self.put_expression(index, context, true)?;
+ write!(self.out, "), ")?;
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{limit_method}() - 1)")?;
+ Ok(())
+ }
+
+ fn put_restricted_texel_address(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write the coordinate.
+ write!(self.out, "{NAMESPACE}::min(")?;
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_coordinate_limits(image, address.level, context)?;
+ write!(self.out, ")")?;
+
+ // Write the array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
+ }
+
+ // Write the sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
+ }
+
+ // The level of detail should be clamped and cached by
+ // `put_cache_restricted_level`, so we don't need to clamp it here.
+ if let Some(level) = address.level {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+
+ Ok(())
+ }
+
+ /// Write an expression that is true if the given image access is in bounds.
+ fn put_image_access_bounds_check(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let mut conjunction = "";
+
+ // First, check the level of detail. Only if that is in bounds can we
+ // use it to find the appropriate bounds for the coordinates.
+ let level = if let Some(level) = address.level {
+ write!(self.out, "uint(")?;
+ self.put_level_of_detail(level, context)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_mip_levels()")?;
+ conjunction = " && ";
+ Some(level)
+ } else {
+ None
+ };
+
+ // Check sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, "uint(")?;
+ self.put_expression(sample, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_samples()")?;
+ conjunction = " && ";
+ }
+
+ // Check array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, "{conjunction}uint(")?;
+ self.put_expression(array_index, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_array_size()")?;
+ conjunction = " && ";
+ }
+
+ // Finally, check if the coordinates are within bounds.
+ let coord_is_vector = match *context.resolve_type(address.coordinate) {
+ crate::TypeInner::Vector { .. } => true,
+ _ => false,
+ };
+ write!(self.out, "{conjunction}")?;
+ if coord_is_vector {
+ write!(self.out, "{NAMESPACE}::all(")?;
+ }
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, " < ")?;
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ if coord_is_vector {
+ write!(self.out, ")")?;
+ }
+
+ Ok(())
+ }
+
+ fn put_image_load(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mut address: TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.policies.image {
+ proc::BoundsCheckPolicy::Restrict => {
+ // Use the cached restricted level of detail, if any. Omit the
+ // level altogether for 1D textures.
+ if address.level.is_some() {
+ address.level = if context.image_needs_lod(image) {
+ Some(LevelOfDetail::Restricted(load))
+ } else {
+ None
+ }
+ }
+
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ self.put_restricted_texel_address(image, &address, context)?;
+ write!(self.out, ")")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "(")?;
+ self.put_image_access_bounds_check(image, &address, context)?;
+ write!(self.out, " ? ")?;
+ self.put_unchecked_image_load(image, &address, context)?;
+ write!(self.out, ": DefaultConstructible())")?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_load(image, &address, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_load(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_expression(sample, context, true)?;
+ }
+ if let Some(level) = address.level {
+ if context.image_needs_lod(image) {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ Ok(())
+ }
+
+ fn put_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ match context.expression.policies.image {
+ proc::BoundsCheckPolicy::Restrict => {
+ // We don't have a restricted level value, because we don't
+ // support writes to mipmapped textures.
+ debug_assert!(address.level.is_none());
+
+ write!(self.out, "{level}")?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ self.put_restricted_texel_address(image, address, &context.expression)?;
+ writeln!(self.out, ");")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "{level}if (")?;
+ self.put_image_access_bounds_check(image, address, &context.expression)?;
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_image_store(level.next(), image, address, value, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_store(level, image, address, value, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ write!(self.out, "{level}")?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, &context.expression, true)?;
+ }
+ writeln!(self.out, ");")?;
+
+ Ok(())
+ }
+
+ fn put_compose(
+ &mut self,
+ ty: Handle<crate::Type>,
+ components: &[Handle<crate::Expression>],
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.module.types[ty].inner {
+ crate::TypeInner::Scalar { width: 4, kind } if components.len() == 1 => {
+ write!(self.out, "{}", kind.to_msl_name())?;
+ self.put_call_parameters(components.iter().cloned(), context)?;
+ }
+ crate::TypeInner::Vector { size, kind, .. } => {
+ put_numeric_type(&mut self.out, kind, &[size])?;
+ self.put_call_parameters(components.iter().cloned(), context)?;
+ }
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Float, &[rows, columns])?;
+ self.put_call_parameters(components.iter().cloned(), context)?;
+ }
+ crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
+ write!(self.out, "{} {{", &self.names[&NameKey::Type(ty)])?;
+ for (index, &component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // insert padding initialization, if needed
+ if self.struct_member_pads.contains(&(ty, index as u32)) {
+ write!(self.out, "{{}}, ")?;
+ }
+ self.put_expression(component, context, true)?;
+ }
+ write!(self.out, "}}")?;
+ }
+ _ => return Err(Error::UnsupportedCompose(ty)),
+ }
+ Ok(())
+ }
+
+ /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
+ ///
+ /// The 'maximum valid index' is simply one less than the array's length.
+ ///
+ /// This emits an expression of the form `a / b`, so the caller must
+ /// parenthesize its output if it will be applying operators of higher
+ /// precedence.
+ ///
+ /// `handle` must be the handle of a global variable whose final member is a
+ /// dynamically sized array.
+ fn put_dynamic_array_max_index(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let global = &context.module.global_variables[handle];
+ let (offset, array_ty) = match context.module.types[global.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => match members.last() {
+ Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
+ None => return Err(Error::Validation),
+ },
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => (0, global.ty),
+ _ => return Err(Error::Validation),
+ };
+
+ let (size, stride) = match context.module.types[array_ty].inner {
+ crate::TypeInner::Array { base, stride, .. } => (
+ context.module.types[base]
+ .inner
+ .size(&context.module.constants),
+ stride,
+ ),
+ _ => return Err(Error::Validation),
+ };
+
+ // When the stride length is larger than the size, the final element's stride of
+ // bytes would have padding following the value. But the buffer size in
+ // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
+ // enough to hold the actual values' bytes.
+ //
+ // So subtract off the size to get a byte size that falls at the start or within
+ // the final element. Then divide by the stride size, to get one less than the
+ // length, and then add one. This works even if the buffer size does include the
+ // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
+ // if there are zero elements in the array, but the WebGPU `validating shader binding`
+ // rules, together with draw-time validation when `minBindingSize` is zero,
+ // prevent that.
+ write!(
+ self.out,
+ "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}",
+ idx = handle.index(),
+ offset = offset,
+ size = size,
+ stride = stride,
+ )?;
+ Ok(())
+ }
+
+ fn put_atomic_fetch(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_atomic_operation(pointer, "fetch_", key, value, context)
+ }
+
+ fn put_atomic_operation(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key1: &str,
+ key2: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // If the pointer we're passing to the atomic operation needs to be conditional
+ // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
+ // the pointer operand should be unchecked.
+ let policy = context.choose_bounds_check_policy(pointer);
+ let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
+
+ // If requested and successfully put bounds checks, continue the ternary expression.
+ if checked {
+ write!(self.out, " ? ")?;
+ }
+
+ write!(
+ self.out,
+ "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, context, true)?;
+ write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
+
+ // Finish the ternary expression.
+ if checked {
+ write!(self.out, " : DefaultConstructible()")?;
+ }
+
+ Ok(())
+ }
+
+ /// Emit code for the arithmetic expression of the dot product.
+ ///
+ fn put_dot_product(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ arg1: Handle<crate::Expression>,
+ size: usize,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write parantheses around the dot product expression to prevent operators
+ // with different precedences from applying earlier.
+ write!(self.out, "(")?;
+
+ // Cycle trough all the components of the vector
+ for index in 0..size {
+ let component = back::COMPONENTS[index];
+ // Write the addition to the previous product
+ // This will print an extra '+' at the beginning but that is fine in msl
+ write!(self.out, " + ")?;
+ // Write the first vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.put_expression(arg, context, true)?;
+ // Access the current component on the first vector
+ write!(self.out, ".{component} * ")?;
+ // Write the second vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.put_expression(arg1, context, true)?;
+ // Access the current component on the second vector
+ write!(self.out, ".{component}")?;
+ }
+
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ /// Emit code for the expression `expr_handle`.
+ ///
+ /// The `is_scoped` argument is true if the surrounding operators have the
+ /// precedence of the comma operator, or lower. So, for example:
+ ///
+ /// - Pass `true` for `is_scoped` when writing function arguments, an
+ /// expression statement, an initializer expression, or anything already
+ /// wrapped in parenthesis.
+ ///
+ /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
+ /// almost anything else.
+ fn put_expression(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_expression_stack_pointers
+ .insert(&expr_handle as *const _ as *const ());
+
+ if let Some(name) = self.named_expressions.get(&expr_handle) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ let expression = &context.function.expressions[expr_handle];
+ log::trace!("expression {:?} = {:?}", expr_handle, expression);
+ match *expression {
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => {
+ // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
+ // Since `put_bounds_checks` and `put_access_chain` handle an entire
+ // access chain at a time, recursing back through `put_expression` only
+ // for index expressions and the base object, we will never see intermediate
+ // `Access` or `AccessIndex` expressions here.
+ let policy = context.choose_bounds_check_policy(base);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ expr_handle,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_access_chain(expr_handle, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_access_chain(expr_handle, policy, context)?;
+ }
+ }
+ crate::Expression::Constant(handle) => {
+ let coco = ConstantContext {
+ handle,
+ arena: &context.module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, "{coco}")?;
+ }
+ crate::Expression::Splat { size, value } => {
+ let scalar_kind = match *context.resolve_type(value) {
+ crate::TypeInner::Scalar { kind, .. } => kind,
+ _ => return Err(Error::Validation),
+ };
+ put_numeric_type(&mut self.out, scalar_kind, &[size])?;
+ write!(self.out, "(")?;
+ self.put_expression(value, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.put_wrapped_expression_for_packed_vec3_access(vector, context, false)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
+ }
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ self.put_compose(ty, components, context)?;
+ }
+ crate::Expression::FunctionArgument(index) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointArgument(ep_index, index)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::LocalVariable(handle) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(fun_handle) => {
+ NameKey::FunctionLocal(fun_handle, handle)
+ }
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointLocal(ep_index, handle)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ let main_op = match gather {
+ Some(_) => "gather",
+ None => "sample",
+ };
+ let comparison_op = match depth_ref {
+ Some(_) => "_compare",
+ None => "",
+ };
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{main_op}{comparison_op}(")?;
+ self.put_expression(sampler, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(coordinate, context, true)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(dref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.put_expression(dref, context, true)?;
+ }
+
+ self.put_image_sample_level(image, level, context)?;
+
+ if let Some(constant) = offset {
+ let coco = ConstantContext {
+ handle: constant,
+ arena: &context.module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, ", {coco}")?;
+ }
+ match gather {
+ None | Some(crate::SwizzleComponent::X) => {}
+ Some(component) => {
+ let is_cube_map = match *context.resolve_type(image) {
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ ..
+ } => true,
+ _ => false,
+ };
+ // Offset always comes before the gather, except
+ // in cube maps where it's not applicable
+ if offset.is_none() && !is_cube_map {
+ write!(self.out, ", {NAMESPACE}::int2(0)")?;
+ }
+ let letter = back::COMPONENTS[component as usize];
+ write!(self.out, ", {NAMESPACE}::component::{letter}")?;
+ }
+ }
+ write!(self.out, ")")?;
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample,
+ level: level.map(LevelOfDetail::Direct),
+ };
+ self.put_image_load(expr_handle, image, address, context)?;
+ }
+ //Note: for all the queries, the signed integers are expected,
+ // so a conversion is needed.
+ crate::Expression::ImageQuery { image, query } => match query {
+ crate::ImageQuery::Size { level } => {
+ self.put_image_size_query(
+ image,
+ level.map(LevelOfDetail::Direct),
+ crate::ScalarKind::Uint,
+ context,
+ )?;
+ }
+ crate::ImageQuery::NumLevels => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_mip_levels()")?;
+ }
+ crate::ImageQuery::NumLayers => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_array_size()")?;
+ }
+ crate::ImageQuery::NumSamples => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_samples()")?;
+ }
+ },
+ crate::Expression::Unary { op, expr } => {
+ use crate::{ScalarKind as Sk, UnaryOperator as Uo};
+ let op_str = match op {
+ Uo::Negate => "-",
+ Uo::Not => match context.resolve_type(expr).scalar_kind() {
+ Some(Sk::Sint) | Some(Sk::Uint) => "~",
+ Some(Sk::Bool) => "!",
+ _ => return Err(Error::Validation),
+ },
+ };
+ write!(self.out, "{op_str}(")?;
+ self.put_expression(expr, context, false)?;
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let op_str = crate::back::binary_operation_str(op);
+ let kind = context
+ .resolve_type(left)
+ .scalar_kind()
+ .ok_or(Error::UnsupportedBinaryOp(op))?;
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
+ write!(self.out, "{NAMESPACE}::fmod(")?;
+ self.put_expression(left, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(right, context, true)?;
+ write!(self.out, ")")?;
+ } else {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+
+ // Cast packed vector if necessary
+ // Packed vector - matrix multiplications are not supported in MSL
+ if op == crate::BinaryOperator::Multiply
+ && matches!(
+ context.resolve_type(right),
+ &crate::TypeInner::Matrix { .. }
+ )
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(left, context, false)?;
+ } else {
+ self.put_expression(left, context, false)?;
+ }
+
+ write!(self.out, " {op_str} ")?;
+
+ // See comment above
+ if op == crate::BinaryOperator::Multiply
+ && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(right, context, false)?;
+ } else {
+ self.put_expression(right, context, false)?;
+ }
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => match *context.resolve_type(condition) {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ } => {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ self.put_expression(condition, context, false)?;
+ write!(self.out, " ? ")?;
+ self.put_expression(accept, context, false)?;
+ write!(self.out, " : ")?;
+ self.put_expression(reject, context, false)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Bool,
+ ..
+ } => {
+ write!(self.out, "{NAMESPACE}::select(")?;
+ self.put_expression(reject, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(accept, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(condition, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ crate::Expression::Derivative { axis, expr, .. } => {
+ use crate::DerivativeAxis as Axis;
+ let op = match axis {
+ Axis::X => "dfdx",
+ Axis::Y => "dfdy",
+ Axis::Width => "fwidth",
+ };
+ write!(self.out, "{NAMESPACE}::{op}")?;
+ self.put_call_parameters(iter::once(expr), context)?;
+ }
+ crate::Expression::Relational { fun, argument } => {
+ let op = match fun {
+ crate::RelationalFunction::Any => "any",
+ crate::RelationalFunction::All => "all",
+ crate::RelationalFunction::IsNan => "isnan",
+ crate::RelationalFunction::IsInf => "isinf",
+ crate::RelationalFunction::IsFinite => "isfinite",
+ crate::RelationalFunction::IsNormal => "isnormal",
+ };
+ write!(self.out, "{NAMESPACE}::{op}")?;
+ self.put_call_parameters(iter::once(argument), context)?;
+ }
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let scalar_argument = match *context.resolve_type(arg) {
+ crate::TypeInner::Scalar { .. } => true,
+ _ => false,
+ };
+
+ let fun_name = match fun {
+ // comparison
+ Mf::Abs => "abs",
+ Mf::Min => "min",
+ Mf::Max => "max",
+ Mf::Clamp => "clamp",
+ Mf::Saturate => "saturate",
+ // trigonometry
+ Mf::Cos => "cos",
+ Mf::Cosh => "cosh",
+ Mf::Sin => "sin",
+ Mf::Sinh => "sinh",
+ Mf::Tan => "tan",
+ Mf::Tanh => "tanh",
+ Mf::Acos => "acos",
+ Mf::Asin => "asin",
+ Mf::Atan => "atan",
+ Mf::Atan2 => "atan2",
+ Mf::Asinh => "asinh",
+ Mf::Acosh => "acosh",
+ Mf::Atanh => "atanh",
+ Mf::Radians => "",
+ Mf::Degrees => "",
+ // decomposition
+ Mf::Ceil => "ceil",
+ Mf::Floor => "floor",
+ Mf::Round => "rint",
+ Mf::Fract => "fract",
+ Mf::Trunc => "trunc",
+ Mf::Modf => "modf",
+ Mf::Frexp => "frexp",
+ Mf::Ldexp => "ldexp",
+ // exponent
+ Mf::Exp => "exp",
+ Mf::Exp2 => "exp2",
+ Mf::Log => "log",
+ Mf::Log2 => "log2",
+ Mf::Pow => "pow",
+ // geometry
+ Mf::Dot => match *context.resolve_type(arg) {
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Float,
+ ..
+ } => "dot",
+ crate::TypeInner::Vector { size, .. } => {
+ return self.put_dot_product(arg, arg1.unwrap(), size as usize, context)
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
+ Mf::Cross => "cross",
+ Mf::Distance => "distance",
+ Mf::Length if scalar_argument => "abs",
+ Mf::Length => "length",
+ Mf::Normalize => "normalize",
+ Mf::FaceForward => "faceforward",
+ Mf::Reflect => "reflect",
+ Mf::Refract => "refract",
+ // computational
+ Mf::Sign => "sign",
+ Mf::Fma => "fma",
+ Mf::Mix => "mix",
+ Mf::Step => "step",
+ Mf::SmoothStep => "smoothstep",
+ Mf::Sqrt => "sqrt",
+ Mf::InverseSqrt => "rsqrt",
+ Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
+ Mf::Transpose => "transpose",
+ Mf::Determinant => "determinant",
+ // bits
+ Mf::CountTrailingZeros => "ctz",
+ Mf::CountLeadingZeros => "clz",
+ Mf::CountOneBits => "popcount",
+ Mf::ReverseBits => "reverse_bits",
+ Mf::ExtractBits => "extract_bits",
+ Mf::InsertBits => "insert_bits",
+ Mf::FindLsb => "",
+ Mf::FindMsb => "",
+ // data packing
+ Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
+ Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
+ Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
+ Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
+ Mf::Pack2x16float => "",
+ // data unpacking
+ Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
+ Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
+ Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
+ Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
+ Mf::Unpack2x16float => "",
+ };
+
+ if fun == Mf::Distance && scalar_argument {
+ write!(self.out, "{NAMESPACE}::abs(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, " - ")?;
+ self.put_expression(arg1.unwrap(), context, false)?;
+ write!(self.out, ")")?;
+ } else if fun == Mf::FindLsb {
+ write!(self.out, "((({NAMESPACE}::ctz(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ") + 1) % 33) - 1)")?;
+ } else if fun == Mf::FindMsb {
+ let inner = context.resolve_type(arg);
+
+ write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?;
+
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ write!(self.out, "{NAMESPACE}::select(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", ~")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " < 0)")?;
+ } else {
+ self.put_expression(arg, context, true)?;
+ }
+
+ write!(self.out, "), ")?;
+
+ // or metal will complain that select is ambiguous
+ match *inner {
+ crate::TypeInner::Vector { size, kind, .. } => {
+ let size = back::vector_size_str(size);
+ if let crate::ScalarKind::Sint = kind {
+ write!(self.out, "int{size}")?;
+ } else {
+ write!(self.out, "uint{size}")?;
+ }
+ }
+ crate::TypeInner::Scalar { kind, .. } => {
+ if let crate::ScalarKind::Sint = kind {
+ write!(self.out, "int")?;
+ } else {
+ write!(self.out, "uint")?;
+ }
+ }
+ _ => (),
+ }
+
+ write!(self.out, "(-1), ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == 0 || ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == -1)")?;
+ } else if fun == Mf::Unpack2x16float {
+ write!(self.out, "float2(as_type<half2>(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Pack2x16float {
+ write!(self.out, "as_type<uint>(half2(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Radians {
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 0.017453292519943295474)")?;
+ } else if fun == Mf::Degrees {
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 57.295779513082322865)")?;
+ } else {
+ write!(self.out, "{NAMESPACE}::{fun_name}")?;
+ self.put_call_parameters(
+ iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
+ context,
+ )?;
+ }
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar {
+ kind: src_kind,
+ width: src_width,
+ }
+ | crate::TypeInner::Vector {
+ kind: src_kind,
+ width: src_width,
+ ..
+ } => {
+ let is_bool_cast =
+ kind == crate::ScalarKind::Bool || src_kind == crate::ScalarKind::Bool;
+ let op = match convert {
+ Some(w) if w == src_width || is_bool_cast => "static_cast",
+ Some(8) if kind == crate::ScalarKind::Float => {
+ return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
+ }
+ Some(_) => return Err(Error::Validation),
+ None => "as_type",
+ };
+ write!(self.out, "{op}<")?;
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, kind, &[size])?
+ }
+ _ => put_numeric_type(&mut self.out, kind, &[])?,
+ };
+ write!(self.out, ">(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(&mut self.out, kind, &[rows, columns])?;
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ // has to be a named expression
+ crate::Expression::CallResult(_)
+ | crate::Expression::AtomicResult { .. }
+ | crate::Expression::RayQueryProceedResult => {
+ unreachable!()
+ }
+ crate::Expression::ArrayLength(expr) => {
+ // Find the global to which the array belongs.
+ let global = match context.function.expressions[expr] {
+ crate::Expression::AccessIndex { base, .. } => {
+ match context.function.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ };
+
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ crate::Expression::RayQueryGetIntersection { query, committed } => {
+ if !committed {
+ unimplemented!()
+ }
+ let ty = context.module.special_types.ray_intersection.unwrap();
+ let type_name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
+ self.put_expression(query, context, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
+ let fields = [
+ "distance",
+ "user_instance_id",
+ "instance_id",
+ "", // SBT offset
+ "geometry_id",
+ "primitive_id",
+ "triangle_barycentric_coord",
+ "triangle_front_facing",
+ "", // padding
+ "object_to_world_transform",
+ "world_to_object_transform",
+ ];
+ for field in fields {
+ write!(self.out, ", ")?;
+ if field.is_empty() {
+ write!(self.out, "{{}}")?;
+ } else {
+ self.put_expression(query, context, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
+ }
+ }
+ write!(self.out, "}}")?;
+ }
+ }
+ Ok(())
+ }
+
+ /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
+ fn put_wrapped_expression_for_packed_vec3_access(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ if let Some(scalar_kind) = context.get_packed_vec_kind(expr_handle) {
+ write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?;
+ self.put_expression(expr_handle, context, is_scoped)?;
+ write!(self.out, ")")?;
+ } else {
+ self.put_expression(expr_handle, context, is_scoped)?;
+ }
+ Ok(())
+ }
+
+ /// Write a `GuardedIndex` as a Metal expression.
+ fn put_index(
+ &mut self,
+ index: index::GuardedIndex,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ match index {
+ index::GuardedIndex::Expression(expr) => {
+ self.put_expression(expr, context, is_scoped)?
+ }
+ index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
+ }
+ Ok(())
+ }
+
+ /// Emit an index bounds check condition for `chain`, if required.
+ ///
+ /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
+ /// operating either on a pointer to a value, or on a value directly. If we cannot
+ /// statically determine that all indexing operations in `chain` are within
+ /// bounds, then write a conditional expression to check them dynamically,
+ /// and return true. All accesses in the chain are checked by the generated
+ /// expression.
+ ///
+ /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
+ ///
+ /// The text written is of the form:
+ ///
+ /// ```ignore
+ /// {level}{prefix}uint(i) < 4 && uint(j) < 10
+ /// ```
+ ///
+ /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
+ /// statements, presumably these arguments start an indented `if` statement; for
+ /// [`Load`] expressions, the caller is probably building up a ternary `?:`
+ /// expression. In either case, what is written is not a complete syntactic structure
+ /// in its own right, and the caller will have to finish it off if we return `true`.
+ ///
+ /// If no expression is written, return false.
+ ///
+ /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ /// [`Store`]: crate::Statement::Store
+ /// [`Load`]: crate::Expression::Load
+ #[allow(unused_variables)]
+ fn put_bounds_checks(
+ &mut self,
+ mut chain: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ level: back::Level,
+ prefix: &'static str,
+ ) -> Result<bool, Error> {
+ let mut check_written = false;
+
+ // Iterate over the access chain, handling each expression.
+ loop {
+ // Produce a `GuardedIndex`, so we can shared code between the
+ // `Access` and `AccessIndex` cases.
+ let (base, guarded_index) = match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ (base, Some(index::GuardedIndex::Expression(index)))
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ // Don't try to check indices into structs. Validation already took
+ // care of them, and index::needs_guard doesn't handle that case.
+ let mut base_inner = context.resolve_type(base);
+ if let crate::TypeInner::Pointer { base, .. } = *base_inner {
+ base_inner = &context.module.types[base].inner;
+ }
+ match *base_inner {
+ crate::TypeInner::Struct { .. } => (base, None),
+ _ => (base, Some(index::GuardedIndex::Known(index))),
+ }
+ }
+ _ => break,
+ };
+
+ if let Some(index) = guarded_index {
+ if let Some(length) = context.access_needs_check(base, index) {
+ if check_written {
+ write!(self.out, " && ")?;
+ } else {
+ write!(self.out, "{level}{prefix}")?;
+ check_written = true;
+ }
+
+ // Check that the index falls within bounds. Do this with a single
+ // comparison, by casting the index to `uint` first, so that negative
+ // indices become large positive values.
+ write!(self.out, "uint(")?;
+ self.put_index(index, context, true)?;
+ self.out.write_str(") < ")?;
+ match length {
+ index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?
+ }
+ }
+ }
+ }
+
+ chain = base
+ }
+
+ Ok(check_written)
+ }
+
+ /// Write the access chain `chain`.
+ ///
+ /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
+ /// operating either on a pointer to a value, or on a value directly.
+ ///
+ /// Generate bounds checks code only if `policy` is [`Restrict`]. The
+ /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
+ /// that must be handled in the caller.
+ ///
+ /// Handle the entire chain, recursing back into `put_expression` only for index
+ /// expressions and the base expression that originates the pointer or composite value
+ /// being accessed. This allows `put_expression` to assume that any `Access` or
+ /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
+ /// `ReadZeroSkipWrite` checks.
+ ///
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_access_chain(
+ &mut self,
+ chain: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ let mut base_ty = context.resolve_type(base);
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ }
+
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Expression(index),
+ policy,
+ context,
+ )?;
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let base_resolution = &context.info[base].ty;
+ let mut base_ty = base_resolution.inner_with(&context.module.types);
+ let mut base_ty_handle = base_resolution.handle();
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ base_ty_handle = Some(base);
+ }
+
+ // Handle structs and anything else that can use `.x` syntax here, so
+ // `put_subscripted_access_chain` won't have to handle the absurd case of
+ // indexing a struct with an expression.
+ match *base_ty {
+ crate::TypeInner::Struct { .. } => {
+ let base_ty = base_ty_handle.unwrap();
+ self.put_access_chain(base, policy, context)?;
+ let name = &self.names[&NameKey::StructMember(base_ty, index)];
+ write!(self.out, ".{name}")?;
+ }
+ crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
+ self.put_access_chain(base, policy, context)?;
+ // Prior to Metal v2.1 component access for packed vectors wasn't available
+ // however array indexing is
+ if context.get_packed_vec_kind(base).is_some() {
+ write!(self.out, "[{index}]")?;
+ } else {
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
+ }
+ }
+ _ => {
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Known(index),
+ policy,
+ context,
+ )?;
+ }
+ }
+ }
+ _ => self.put_expression(chain, context, false)?,
+ }
+
+ Ok(())
+ }
+
+ /// Write a `[]`-style access of `base` by `index`.
+ ///
+ /// If `policy` is [`Restrict`], then generate code as needed to force all index
+ /// values within bounds.
+ ///
+ /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
+ /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
+ /// removed. Our callers often already have this handy.
+ ///
+ /// This only emits `[]` expressions; it doesn't handle struct member accesses or
+ /// referencing vector components by name.
+ ///
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`Array`]: crate::TypeInner::Array
+ /// [`Vector`]: crate::TypeInner::Vector
+ /// [`Pointer`]: crate::TypeInner::Pointer
+ fn put_subscripted_access_chain(
+ &mut self,
+ base: Handle<crate::Expression>,
+ base_ty: &crate::TypeInner,
+ index: index::GuardedIndex,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let accessing_wrapped_array = match *base_ty {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(_),
+ ..
+ } => true,
+ _ => false,
+ };
+
+ self.put_access_chain(base, policy, context)?;
+ if accessing_wrapped_array {
+ write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
+ }
+ write!(self.out, "[")?;
+
+ // Decide whether this index needs to be clamped to fall within range.
+ let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
+ context.access_needs_check(base, index)
+ } else {
+ None
+ };
+ if let Some(limit) = restriction_needed {
+ write!(self.out, "{NAMESPACE}::min(unsigned(")?;
+ self.put_index(index, context, true)?;
+ write!(self.out, "), ")?;
+ match limit {
+ index::IndexableLength::Known(limit) => {
+ write!(self.out, "{}u", limit - 1)?;
+ }
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ self.put_dynamic_array_max_index(global, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+ } else {
+ self.put_index(index, context, true)?;
+ }
+
+ write!(self.out, "]")?;
+
+ Ok(())
+ }
+
+ fn put_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Since access chains never cross between address spaces, we can just
+ // check the index bounds check policy once at the top.
+ let policy = context.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ pointer,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_unchecked_load(pointer, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_unchecked_load(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let is_atomic = match *context.resolve_type(pointer) {
+ crate::TypeInner::Pointer { base, .. } => match context.module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => true,
+ _ => false,
+ },
+ _ => false,
+ };
+
+ if is_atomic {
+ write!(
+ self.out,
+ "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
+ } else {
+ // We don't do any dereferencing with `*` here as pointer arguments to functions
+ // are done by `&` references and not `*` pointers. These do not need to be
+ // dereferenced.
+ self.put_access_chain(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_return_value(
+ &mut self,
+ level: back::Level,
+ expr_handle: Handle<crate::Expression>,
+ result_struct: Option<&str>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match result_struct {
+ Some(struct_name) => {
+ let mut has_point_size = false;
+ let result_ty = context.function.result.as_ref().unwrap().ty;
+ match context.module.types[result_ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let tmp = "_tmp";
+ write!(self.out, "{level}const auto {tmp} = ")?;
+ self.put_expression(expr_handle, context, true)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}return {struct_name} {{")?;
+
+ let mut is_first = true;
+
+ for (index, member) in members.iter().enumerate() {
+ if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
+ member.binding
+ {
+ has_point_size = true;
+ if !context.pipeline_options.allow_point_size {
+ continue;
+ }
+ }
+
+ let comma = if is_first { "" } else { "," };
+ is_first = false;
+ let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
+ // HACK: we are forcefully deduplicating the expression here
+ // to convert from a wrapped struct to a raw array, e.g.
+ // `float gl_ClipDistance1 [[clip_distance]] [1];`.
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(const_handle),
+ ..
+ } = context.module.types[member.ty].inner
+ {
+ let size = context.module.constants[const_handle]
+ .to_array_length()
+ .unwrap();
+ write!(self.out, "{comma} {{")?;
+ for j in 0..size {
+ if j != 0 {
+ write!(self.out, ",")?;
+ }
+ write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
+ }
+ write!(self.out, "}}")?;
+ } else {
+ write!(self.out, "{comma} {tmp}.{name}")?;
+ }
+ }
+ }
+ _ => {
+ write!(self.out, "{level}return {struct_name} {{ ")?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+
+ if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
+ let stage = context.module.entry_points[ep_index as usize].stage;
+ if context.pipeline_options.allow_point_size
+ && stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // point size was injected and comes last
+ write!(self.out, ", 1.0")?;
+ }
+ }
+ write!(self.out, " }}")?;
+ }
+ None => {
+ write!(self.out, "{level}return ")?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// This function overwrites the contents of `self.need_bake_expressions`
+ fn update_expressions_to_bake(
+ &mut self,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ context: &ExpressionContext,
+ ) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+
+ for (expr_handle, expr) in func.expressions.iter() {
+ // Expressions whose reference count is above the
+ // threshold should always be stored in temporaries.
+ let expr_info = &info[expr_handle];
+ let min_ref_count = func.expressions[expr_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(expr_handle);
+ } else {
+ match expr_info.ty {
+ // force ray desc to be baked: it's used multiple times internally
+ TypeResolution::Handle(h)
+ if Some(h) == context.module.special_types.ray_desc =>
+ {
+ self.need_bake_expressions.insert(expr_handle);
+ }
+ _ => {}
+ }
+ }
+
+ if let Expression::Math { fun, arg, arg1, .. } = *expr {
+ match fun {
+ crate::MathFunction::Dot => {
+ // WGSL's `dot` function works on any `vecN` type, but Metal's only
+ // works on floating-point vectors, so we emit inline code for
+ // integer vector `dot` calls. But that code uses each argument `N`
+ // times, once for each component (see `put_dot_product`), so to
+ // avoid duplicated evaluation, we must bake integer operands.
+
+ // check what kind of product this is depending
+ // on the resolve type of the Dot function itself
+ if let crate::TypeInner::Scalar { kind, .. } =
+ *context.resolve_type(expr_handle)
+ {
+ match kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ _ => {}
+ }
+ }
+ }
+ crate::MathFunction::FindMsb => {
+ self.need_bake_expressions.insert(arg);
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+
+ fn start_baking_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ name: &str,
+ ) -> BackendResult {
+ match context.info[handle].ty {
+ TypeResolution::Handle(ty_handle) => {
+ let ty_name = TypeContext {
+ handle: ty_handle,
+ module: context.module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ }
+ TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => {
+ put_numeric_type(&mut self.out, kind, &[])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Vector { size, kind, .. }) => {
+ put_numeric_type(&mut self.out, kind, &[size])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Matrix { columns, rows, .. }) => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Float, &[rows, columns])?;
+ }
+ TypeResolution::Value(ref other) => {
+ log::warn!("Type {:?} isn't a known local", other); //TEMP!
+ return Err(Error::FeatureNotImplemented("weird local type".to_string()));
+ }
+ }
+
+ //TODO: figure out the naming scheme that wouldn't collide with user names.
+ write!(self.out, " {name} = ")?;
+
+ Ok(())
+ }
+
+ /// Cache a clamped level of detail value, if necessary.
+ ///
+ /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
+ /// properly clamped level of detail value both in the access itself, and
+ /// for fetching the size of the requested MIP level, needed to clamp the
+ /// coordinates. To avoid recomputing this clamped level of detail, we cache
+ /// it in a temporary variable, as part of the [`Emit`] statement covering
+ /// the [`ImageLoad`] expression.
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+ /// [`Emit`]: crate::Statement::Emit
+ fn put_cache_restricted_level(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mip_level: Option<Handle<crate::Expression>>,
+ indent: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Does this image access actually require (or even permit) a
+ // level-of-detail, and does the policy require us to restrict it?
+ let level_of_detail = match mip_level {
+ Some(level) => level,
+ None => return Ok(()),
+ };
+
+ if context.expression.policies.image != index::BoundsCheckPolicy::Restrict
+ || !context.expression.image_needs_lod(image)
+ {
+ return Ok(());
+ }
+
+ write!(
+ self.out,
+ "{}uint {}{} = ",
+ indent,
+ CLAMPED_LOD_LOAD_PREFIX,
+ load.index(),
+ )?;
+ self.put_restricted_scalar_image_index(
+ image,
+ level_of_detail,
+ "get_num_mip_levels",
+ &context.expression,
+ )?;
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ fn put_block(
+ &mut self,
+ level: back::Level,
+ statements: &[crate::Statement],
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_block_stack_pointers
+ .insert(&level as *const _ as *const ());
+
+ for statement in statements {
+ log::trace!("statement[{}] {:?}", level.0, statement);
+ match *statement {
+ crate::Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ // `ImageLoad` expressions covered by the `Restrict` bounds check policy
+ // may need to cache a clamped version of their level-of-detail argument.
+ if let crate::Expression::ImageLoad {
+ image,
+ level: mip_level,
+ ..
+ } = context.expression.function.expressions[handle]
+ {
+ self.put_cache_restricted_level(
+ handle, image, mip_level, level, context,
+ )?;
+ }
+
+ let info = &context.expression.info[handle];
+ let ptr_class = info
+ .ty
+ .inner_with(&context.expression.module.types)
+ .pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ None // don't bake pointer expressions (just yet)
+ } else if let Some(name) =
+ context.expression.function.named_expressions.get(&handle)
+ {
+ // The `crate::Function::named_expressions` table holds
+ // expressions that should be saved in temporaries once they
+ // are `Emit`ted. We only add them to `self.named_expressions`
+ // when we reach the `Emit` that covers them, so that we don't
+ // try to use their names before we've actually initialized
+ // the temporary that holds them.
+ //
+ // Don't assume the names in `named_expressions` are unique,
+ // or even valid. Use the `Namer`.
+ Some(self.namer.call(name))
+ } else if info.ref_count == 0 {
+ Some(self.namer.call(""))
+ } else {
+ // If this expression is an index that we're going to first compare
+ // against a limit, and then actually use as an index, then we may
+ // want to cache it in a temporary, to avoid evaluating it twice.
+ let bake =
+ if context.expression.guarded_indices.contains(handle.index()) {
+ true
+ } else {
+ self.need_bake_expressions.contains(&handle)
+ };
+
+ if bake {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.start_baking_expression(handle, &context.expression, &name)?;
+ self.put_expression(handle, &context.expression, true)?;
+ self.named_expressions.insert(handle, name);
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ crate::Statement::Block(ref block) => {
+ if !block.is_empty() {
+ writeln!(self.out, "{level}{{")?;
+ self.put_block(level.next(), block, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}if (")?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ self.put_block(level.next(), accept, context)?;
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+ self.put_block(level.next(), reject, context)?;
+ }
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ write!(self.out, "{level}switch(")?;
+ self.put_expression(selector, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ let lcase = level.next();
+ for case in cases.iter() {
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ write!(self.out, "{lcase}case {value}:")?;
+ }
+ crate::SwitchValue::U32(value) => {
+ write!(self.out, "{lcase}case {value}u:")?;
+ }
+ crate::SwitchValue::Default => {
+ write!(self.out, "{lcase}default:")?;
+ }
+ }
+
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ self.put_block(lcase.next(), &case.body, context)?;
+ if !case.fall_through
+ && case.body.last().map_or(true, |s| !s.is_terminator())
+ {
+ writeln!(self.out, "{}break;", lcase.next())?;
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{lcase}}}")?;
+ }
+ }
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ let lif = level.next();
+ let lcontinuing = lif.next();
+ writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
+ self.put_block(lcontinuing, continuing, context)?;
+ if let Some(condition) = break_if {
+ write!(self.out, "{lcontinuing}if (")?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", lcontinuing.next())?;
+ writeln!(self.out, "{lcontinuing}}}")?;
+ }
+ writeln!(self.out, "{lif}}}")?;
+ writeln!(self.out, "{lif}{gate_name} = false;")?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+ self.put_block(level.next(), body, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Break => {
+ writeln!(self.out, "{level}break;")?;
+ }
+ crate::Statement::Continue => {
+ writeln!(self.out, "{level}continue;")?;
+ }
+ crate::Statement::Return {
+ value: Some(expr_handle),
+ } => {
+ self.put_return_value(
+ level,
+ expr_handle,
+ context.result_struct,
+ &context.expression,
+ )?;
+ }
+ crate::Statement::Return { value: None } => {
+ writeln!(self.out, "{level}return;")?;
+ }
+ crate::Statement::Kill => {
+ writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
+ }
+ crate::Statement::Barrier(flags) => {
+ self.write_barrier(flags, level)?;
+ }
+ crate::Statement::Store { pointer, value } => {
+ self.put_store(pointer, value, level, context)?
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample: None,
+ level: None,
+ };
+ self.put_image_store(level, image, &address, value, context)?
+ }
+ crate::Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ self.start_baking_expression(expr, &context.expression, &name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let fun_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{fun_name}(")?;
+ // first, write down the actual arguments
+ for (i, &handle) in arguments.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_expression(handle, &context.expression, true)?;
+ }
+ // follow-up with any global resources used
+ let mut separate = !arguments.is_empty();
+ let fun_info = &context.mod_info[function];
+ let mut supports_array_length = false;
+ for (handle, var) in context.expression.module.global_variables.iter() {
+ if fun_info[handle].is_empty() {
+ continue;
+ }
+ if var.space.needs_pass_through() {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ if separate {
+ write!(self.out, ", ")?;
+ } else {
+ separate = true;
+ }
+ write!(self.out, "{name}")?;
+ }
+ supports_array_length |=
+ needs_array_length(var.ty, &context.expression.module.types);
+ }
+ if supports_array_length {
+ if separate {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "_buffer_sizes")?;
+ }
+
+ // done
+ writeln!(self.out, ");")?;
+ }
+ crate::Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_baking_expression(result, &context.expression, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+ match *fun {
+ crate::AtomicFunction::Add => {
+ self.put_atomic_fetch(pointer, "add", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Subtract => {
+ self.put_atomic_fetch(pointer, "sub", value, &context.expression)?;
+ }
+ crate::AtomicFunction::And => {
+ self.put_atomic_fetch(pointer, "and", value, &context.expression)?;
+ }
+ crate::AtomicFunction::InclusiveOr => {
+ self.put_atomic_fetch(pointer, "or", value, &context.expression)?;
+ }
+ crate::AtomicFunction::ExclusiveOr => {
+ self.put_atomic_fetch(pointer, "xor", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Min => {
+ self.put_atomic_fetch(pointer, "min", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Max => {
+ self.put_atomic_fetch(pointer, "max", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Exchange { compare: None } => {
+ self.put_atomic_operation(
+ pointer,
+ "exchange",
+ "",
+ value,
+ &context.expression,
+ )?;
+ }
+ crate::AtomicFunction::Exchange { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "atomic CompareExchange".to_string(),
+ ));
+ }
+ }
+ // done
+ writeln!(self.out, ";")?;
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ //TODO: how to deal with winding?
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
+ {
+ let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
+ let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
+ writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
+ }
+ {
+ let f_opaque = back::RayFlag::OPAQUE.bits();
+ let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
+ writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
+ }
+ {
+ let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ writeln!(self.out, ".flags & {flag}) != 0);")?;
+ }
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".origin, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".dir, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".tmin, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".tmax), ")?;
+ self.put_expression(acceleration_structure, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".cull_mask);")?;
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ write!(self.out, "{level}")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
+ //TODO: actually proceed?
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
+ }
+ crate::RayQueryFunction::Terminate => {
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?;
+ }
+ }
+ }
+ }
+ }
+
+ // un-emit expressions
+ //TODO: take care of loop/continuing?
+ for statement in statements {
+ if let crate::Statement::Emit(ref range) = *statement {
+ for handle in range.clone() {
+ self.named_expressions.remove(&handle);
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn put_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let policy = context.expression.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
+ {
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
+ writeln!(self.out, "{level}}}")?;
+ } else {
+ self.put_unchecked_store(pointer, value, policy, level, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let pointer_inner = context.expression.resolve_type(pointer);
+ let (array_size, is_atomic) = match *pointer_inner {
+ crate::TypeInner::Pointer { base, .. } => {
+ match context.expression.module.types[base].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(ch),
+ ..
+ } => (Some(ch), false),
+ crate::TypeInner::Atomic { .. } => (None, true),
+ _ => (None, false),
+ }
+ }
+ _ => (None, false),
+ };
+
+ // we can't assign fixed-size arrays
+ if let Some(const_handle) = array_size {
+ let size = context.expression.module.constants[const_handle]
+ .to_array_length()
+ .unwrap();
+ write!(self.out, "{level}for(int _i=0; _i<{size}; ++_i) ")?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, ".{WRAPPED_ARRAY_FIELD}[_i] = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ".{WRAPPED_ARRAY_FIELD}[_i];")?;
+ } else if is_atomic {
+ write!(
+ self.out,
+ "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
+ } else {
+ write!(self.out, "{level}")?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, " = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ pub fn write(
+ &mut self,
+ module: &crate::Module,
+ info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ self.names.clear();
+ self.namer
+ .reset(module, super::keywords::RESERVED, &[], &mut self.names);
+ self.struct_member_pads.clear();
+
+ writeln!(
+ self.out,
+ "// language: metal{}.{}",
+ options.lang_version.0, options.lang_version.1
+ )?;
+ writeln!(self.out, "#include <metal_stdlib>")?;
+ writeln!(self.out, "#include <simd/simd.h>")?;
+ writeln!(self.out)?;
+ // Work around Metal bug where `uint` is not available by default
+ writeln!(self.out, "using {NAMESPACE}::uint;")?;
+
+ if module.types.iter().any(|(_, t)| match t.inner {
+ crate::TypeInner::RayQuery => true,
+ _ => false,
+ }) {
+ let tab = back::INDENT;
+ writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
+ let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
+ writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
+ writeln!(
+ self.out,
+ "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
+ )?;
+ writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
+ writeln!(self.out, "}};")?;
+ writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
+ let v_triangle = back::RayIntersectionType::Triangle as u32;
+ let v_bbox = back::RayIntersectionType::BoundingBox as u32;
+ writeln!(
+ self.out,
+ "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
+ )?;
+ writeln!(
+ self.out,
+ "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
+ )?;
+ writeln!(self.out, "}}")?;
+ }
+ if options
+ .bounds_check_policies
+ .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
+ {
+ self.put_default_constructible()?;
+ }
+ writeln!(self.out)?;
+
+ {
+ let mut indices = vec![];
+ for (handle, var) in module.global_variables.iter() {
+ if needs_array_length(var.ty, &module.types) {
+ let idx = handle.index();
+ indices.push(idx);
+ }
+ }
+
+ if !indices.is_empty() {
+ writeln!(self.out, "struct _mslBufferSizes {{")?;
+
+ for idx in indices {
+ writeln!(self.out, "{}uint size{};", back::INDENT, idx)?;
+ }
+
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+ }
+ };
+
+ self.write_scalar_constants(module)?;
+ self.write_type_defs(module)?;
+ self.write_composite_constants(module)?;
+ self.write_functions(module, info, options, pipeline_options)
+ }
+
+ /// Write the definition for the `DefaultConstructible` class.
+ ///
+ /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
+ /// produce 'zero' values for any type, including structs, arrays, and so
+ /// on. We could do this by emitting default constructor applications, but
+ /// that would entail printing the name of the type, which is more trouble
+ /// than you'd think. Instead, we just construct this magic C++14 class that
+ /// can be converted to any type that can be default constructed, using
+ /// template parameter inference to detect which type is needed, so we don't
+ /// have to figure out the name.
+ ///
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_default_constructible(&mut self) -> BackendResult {
+ let tab = back::INDENT;
+ writeln!(self.out, "struct DefaultConstructible {{")?;
+ writeln!(self.out, "{tab}template<typename T>")?;
+ writeln!(self.out, "{tab}operator T() && {{")?;
+ writeln!(self.out, "{tab}{tab}return T {{}};")?;
+ writeln!(self.out, "{tab}}}")?;
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, ty) in module.types.iter() {
+ if !ty.needs_alias() {
+ continue;
+ }
+ let name = &self.names[&NameKey::Type(handle)];
+ match ty.inner {
+ // Naga IR can pass around arrays by value, but Metal, following
+ // C++, performs an array-to-pointer conversion (C++ [conv.array])
+ // on expressions of array type, so assigning the array by value
+ // isn't possible. However, Metal *does* assign structs by
+ // value. So in our Metal output, we wrap all array types in
+ // synthetic struct types:
+ //
+ // struct type1 {
+ // float inner[10]
+ // };
+ //
+ // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
+ // any expression that actually wants access to the array.
+ crate::TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ let base_name = TypeContext {
+ handle: base,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let coco = ConstantContext {
+ handle: const_handle,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+
+ writeln!(self.out, "struct {name} {{")?;
+ writeln!(
+ self.out,
+ "{}{} {}[{}];",
+ back::INDENT,
+ base_name,
+ WRAPPED_ARRAY_FIELD,
+ coco
+ )?;
+ writeln!(self.out, "}};")?;
+ }
+ crate::ArraySize::Dynamic => {
+ writeln!(self.out, "typedef {base_name} {name}[1];")?;
+ }
+ }
+ }
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => {
+ writeln!(self.out, "struct {name} {{")?;
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ // quick and dirty way to figure out if we need this...
+ if member.binding.is_none() && member.offset > last_offset {
+ self.struct_member_pads.insert((handle, index as u32));
+ let pad = member.offset - last_offset;
+ writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset + ty_inner.size(&module.constants);
+
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+
+ // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
+ match should_pack_struct_member(members, span, index, module) {
+ Some(kind) => {
+ writeln!(
+ self.out,
+ "{}{}::packed_{}3 {};",
+ back::INDENT,
+ NAMESPACE,
+ kind.to_msl_name(),
+ member_name
+ )?;
+ }
+ None => {
+ let base_name = TypeContext {
+ handle: member.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ writeln!(
+ self.out,
+ "{}{} {};",
+ back::INDENT,
+ base_name,
+ member_name
+ )?;
+
+ // for 3-component vectors, add one component
+ if let crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ kind: _,
+ width,
+ } = *ty_inner
+ {
+ last_offset += width as u32;
+ }
+ }
+ }
+ }
+ writeln!(self.out, "}};")?;
+ }
+ _ => {
+ let ty_name = TypeContext {
+ handle,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ writeln!(self.out, "typedef {ty_name} {name};")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn write_scalar_constants(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, constant) in module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } if constant.name.is_some() => {
+ debug_assert!(constant.needs_alias());
+ write!(self.out, "constexpr constant ")?;
+ match *value {
+ crate::ScalarValue::Sint(_) => {
+ write!(self.out, "int")?;
+ }
+ crate::ScalarValue::Uint(_) => {
+ write!(self.out, "unsigned")?;
+ }
+ crate::ScalarValue::Float(_) => {
+ write!(self.out, "float")?;
+ }
+ crate::ScalarValue::Bool(_) => {
+ write!(self.out, "bool")?;
+ }
+ }
+ let name = &self.names[&NameKey::Constant(handle)];
+ let coco = ConstantContext {
+ handle,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: true,
+ };
+ writeln!(self.out, " {name} = {coco};")?;
+ }
+ _ => {}
+ }
+ }
+ Ok(())
+ }
+
+ fn write_composite_constants(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, constant) in module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Scalar { .. } => {}
+ crate::ConstantInner::Composite { ty, ref components } => {
+ debug_assert!(constant.needs_alias());
+ let name = &self.names[&NameKey::Constant(handle)];
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "constant {ty_name} {name} = {{",)?;
+ for (i, &sub_handle) in components.iter().enumerate() {
+ // insert padding initialization, if needed
+ if self.struct_member_pads.contains(&(ty, i as u32)) {
+ write!(self.out, ", {{}}")?;
+ }
+ let separator = if i != 0 { ", " } else { "" };
+ let coco = ConstantContext {
+ handle: sub_handle,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, "{separator}{coco}")?;
+ }
+ writeln!(self.out, "}};")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn put_inline_sampler_properties(
+ &mut self,
+ level: back::Level,
+ sampler: &sm::InlineSampler,
+ ) -> BackendResult {
+ for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
+ writeln!(
+ self.out,
+ "{}{}::{}_address::{},",
+ level,
+ NAMESPACE,
+ letter,
+ address.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::mag_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.mag_filter.as_str(),
+ )?;
+ writeln!(
+ self.out,
+ "{}{}::min_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.min_filter.as_str(),
+ )?;
+ if let Some(filter) = sampler.mip_filter {
+ writeln!(
+ self.out,
+ "{}{}::mip_filter::{},",
+ level,
+ NAMESPACE,
+ filter.as_str(),
+ )?;
+ }
+ // avoid setting it on platforms that don't support it
+ if sampler.border_color != sm::BorderColor::TransparentBlack {
+ writeln!(
+ self.out,
+ "{}{}::border_color::{},",
+ level,
+ NAMESPACE,
+ sampler.border_color.as_str(),
+ )?;
+ }
+ //TODO: I'm not able to feed this in a way that MSL likes:
+ //>error: use of undeclared identifier 'lod_clamp'
+ //>error: no member named 'max_anisotropy' in namespace 'metal'
+ if false {
+ if let Some(ref lod) = sampler.lod_clamp {
+ writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
+ }
+ if let Some(aniso) = sampler.max_anisotropy {
+ writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
+ }
+ }
+ if sampler.compare_func != sm::CompareFunc::Never {
+ writeln!(
+ self.out,
+ "{}{}::compare_func::{},",
+ level,
+ NAMESPACE,
+ sampler.compare_func.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::coord::{}",
+ level,
+ NAMESPACE,
+ sampler.coord.as_str()
+ )?;
+ Ok(())
+ }
+
+ // Returns the array of mapped entry point names.
+ fn write_functions(
+ &mut self,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ let mut pass_through_globals = Vec::new();
+ for (fun_handle, fun) in module.functions.iter() {
+ log::trace!(
+ "function {:?}, handle {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ fun_handle
+ );
+
+ let fun_info = &mod_info[fun_handle];
+ pass_through_globals.clear();
+ let mut supports_array_length = false;
+ for (handle, var) in module.global_variables.iter() {
+ if !fun_info[handle].is_empty() {
+ if var.space.needs_pass_through() {
+ pass_through_globals.push(handle);
+ }
+ supports_array_length |= needs_array_length(var.ty, &module.types);
+ }
+ }
+
+ writeln!(self.out)?;
+ let fun_name = &self.names[&NameKey::Function(fun_handle)];
+ match fun.result {
+ Some(ref result) => {
+ let ty_name = TypeContext {
+ handle: result.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ }
+ None => {
+ write!(self.out, "void")?;
+ }
+ }
+ writeln!(self.out, " {fun_name}(")?;
+
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
+ let param_type_name = TypeContext {
+ handle: arg.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let separator = separate(
+ !pass_through_globals.is_empty()
+ || index + 1 != fun.arguments.len()
+ || supports_array_length,
+ );
+ writeln!(
+ self.out,
+ "{}{} {}{}",
+ back::INDENT,
+ param_type_name,
+ name,
+ separator
+ )?;
+ }
+ for (index, &handle) in pass_through_globals.iter().enumerate() {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage: fun_info[handle],
+ binding: None,
+ reference: true,
+ };
+ let separator =
+ separate(index + 1 != pass_through_globals.len() || supports_array_length);
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ writeln!(self.out, "{separator}")?;
+ }
+
+ if supports_array_length {
+ writeln!(
+ self.out,
+ "{}constant _mslBufferSizes& _buffer_sizes",
+ back::INDENT
+ )?;
+ }
+
+ writeln!(self.out, ") {{")?;
+
+ for (local_handle, local) in fun.local_variables.iter() {
+ let ty_name = TypeContext {
+ handle: local.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?;
+ match local.init {
+ Some(value) => {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, " = {coco}")?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::Handle(fun_handle),
+ info: fun_info,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ pipeline_options,
+ },
+ mod_info,
+ result_struct: None,
+ };
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ }
+
+ let mut info = TranslationInfo {
+ entry_point_names: Vec::with_capacity(module.entry_points.len()),
+ };
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let fun = &ep.function;
+ let fun_info = mod_info.get_entry_point(ep_index);
+ let mut ep_error = None;
+
+ log::trace!(
+ "entry point {:?}, index {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ ep_index
+ );
+
+ // Is any global variable used by this entry point dynamically sized?
+ let supports_array_length = module
+ .global_variables
+ .iter()
+ .filter(|&(handle, _)| !fun_info[handle].is_empty())
+ .any(|(_, var)| needs_array_length(var.ty, &module.types));
+
+ // skip this entry point if any global bindings are missing,
+ // or their types are incompatible.
+ if !options.fake_missing_bindings {
+ for (var_handle, var) in module.global_variables.iter() {
+ if fun_info[var_handle].is_empty() {
+ continue;
+ }
+ match var.space {
+ crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::Handle => {
+ let br = match var.binding {
+ Some(ref br) => br,
+ None => {
+ let var_name = var.name.clone().unwrap_or_default();
+ ep_error =
+ Some(super::EntryPointError::MissingBinding(var_name));
+ break;
+ }
+ };
+ let target = options.get_resource_binding_target(ep, br);
+ let good = match target {
+ Some(target) => {
+ let binding_ty = match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => {
+ &module.types[base].inner
+ }
+ ref ty => ty,
+ };
+ match *binding_ty {
+ crate::TypeInner::Image { .. } => target.texture.is_some(),
+ crate::TypeInner::Sampler { .. } => {
+ target.sampler.is_some()
+ }
+ _ => target.buffer.is_some(),
+ }
+ }
+ None => false,
+ };
+ if !good {
+ ep_error =
+ Some(super::EntryPointError::MissingBindTarget(br.clone()));
+ break;
+ }
+ }
+ crate::AddressSpace::PushConstant => {
+ if let Err(e) = options.resolve_push_constants(ep) {
+ ep_error = Some(e);
+ break;
+ }
+ }
+ crate::AddressSpace::Function
+ | crate::AddressSpace::Private
+ | crate::AddressSpace::WorkGroup => {}
+ }
+ }
+ if supports_array_length {
+ if let Err(err) = options.resolve_sizes_buffer(ep) {
+ ep_error = Some(err);
+ }
+ }
+ }
+
+ if let Some(err) = ep_error {
+ info.entry_point_names.push(Err(err));
+ continue;
+ }
+ let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
+ info.entry_point_names.push(Ok(fun_name.clone()));
+
+ writeln!(self.out)?;
+
+ let (em_str, in_mode, out_mode) = match ep.stage {
+ crate::ShaderStage::Vertex => (
+ "vertex",
+ LocationMode::VertexInput,
+ LocationMode::VertexOutput,
+ ),
+ crate::ShaderStage::Fragment { .. } => (
+ "fragment",
+ LocationMode::FragmentInput,
+ LocationMode::FragmentOutput,
+ ),
+ crate::ShaderStage::Compute { .. } => {
+ ("kernel", LocationMode::Uniform, LocationMode::Uniform)
+ }
+ };
+
+ // List all the Naga `EntryPoint`'s `Function`'s arguments,
+ // flattening structs into their members. In Metal, we will pass
+ // each of these values to the entry point as a separate argument—
+ // except for the varyings, handled next.
+ let mut flattened_arguments = Vec::new();
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (member_index, member) in members.iter().enumerate() {
+ let member_index = member_index as u32;
+ flattened_arguments.push((
+ NameKey::StructMember(arg.ty, member_index),
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ }
+ }
+ _ => flattened_arguments.push((
+ NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
+ arg.ty,
+ arg.binding.as_ref(),
+ )),
+ }
+ }
+
+ // Identify the varyings among the argument values, and emit a
+ // struct type named `<fun>Input` to hold them.
+ let stage_in_name = format!("{fun_name}Input");
+ let varyings_member_name = self.namer.call("varyings");
+ let mut has_varyings = false;
+ if !flattened_arguments.is_empty() {
+ writeln!(self.out, "struct {stage_in_name} {{")?;
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(ref binding @ &crate::Binding::Location { .. }) => binding,
+ _ => continue,
+ };
+ has_varyings = true;
+ let name = &self.names[name_key];
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ }
+
+ // Define a struct type named for the return value, if any, named
+ // `<fun>Output`.
+ let stage_out_name = format!("{fun_name}Output");
+ let result_member_name = self.namer.call("member");
+ let result_type_name = match fun.result {
+ Some(ref result) => {
+ let mut result_members = Vec::new();
+ if let crate::TypeInner::Struct { ref members, .. } =
+ module.types[result.ty].inner
+ {
+ for (member_index, member) in members.iter().enumerate() {
+ result_members.push((
+ &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ }
+ } else {
+ result_members.push((
+ &result_member_name,
+ result.ty,
+ result.binding.as_ref(),
+ ));
+ }
+
+ writeln!(self.out, "struct {stage_out_name} {{")?;
+ let mut has_point_size = false;
+ for (name, ty, binding) in result_members {
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ let binding = binding.ok_or(Error::Validation)?;
+
+ if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
+ has_point_size = true;
+ if !pipeline_options.allow_point_size {
+ continue;
+ }
+ }
+
+ let array_len = match module.types[ty].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(handle),
+ ..
+ } => module.constants[handle].to_array_length(),
+ _ => None,
+ };
+ let resolved = options.resolve_local_binding(binding, out_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ if let Some(array_len) = array_len {
+ write!(self.out, " [{array_len}]")?;
+ }
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+
+ if pipeline_options.allow_point_size
+ && ep.stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // inject the point size output last
+ writeln!(
+ self.out,
+ "{}float _point_size [[point_size]];",
+ back::INDENT
+ )?;
+ }
+ writeln!(self.out, "}};")?;
+ &stage_out_name
+ }
+ None => "void",
+ };
+
+ // Write the entry point function's name, and begin its argument list.
+ writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
+ let mut is_first_argument = true;
+
+ // If we have produced a struct holding the `EntryPoint`'s
+ // `Function`'s arguments' varyings, pass that struct first.
+ if has_varyings {
+ writeln!(
+ self.out,
+ " {stage_in_name} {varyings_member_name} [[stage_in]]"
+ )?;
+ is_first_argument = false;
+ }
+
+ let mut local_invocation_id = None;
+
+ // Then pass the remaining arguments not included in the varyings
+ // struct.
+ //
+ // Since `Namer.reset` wasn't expecting struct members to be
+ // suddenly injected into the normal namespace like this,
+ // `self.names` doesn't keep them distinct from other variables.
+ // Generate fresh names for these arguments, and remember the
+ // mapping.
+ let mut flattened_member_names = FastHashMap::default();
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
+ _ => continue,
+ };
+ let name = if let NameKey::StructMember(ty, index) = *name_key {
+ // We should always insert a fresh entry here, but use
+ // `or_insert` to get a reference to the `String` we just
+ // inserted.
+ flattened_member_names
+ .entry(NameKey::StructMember(ty, index))
+ .or_insert_with(|| self.namer.call(&self.names[name_key]))
+ } else {
+ &self.names[name_key]
+ };
+
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
+ local_invocation_id = Some(name_key);
+ }
+
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{separator} {ty_name} {name}")?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ let need_workgroup_variables_initialization =
+ self.need_workgroup_variables_initialization(options, ep, module, fun_info);
+
+ if need_workgroup_variables_initialization && local_invocation_id.is_none() {
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ writeln!(
+ self.out,
+ "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]"
+ )?;
+ }
+
+ // Those global variables used by this entry point and its callees
+ // get passed as arguments. `Private` globals are an exception, they
+ // don't outlive this invocation, so we declare them below as locals
+ // within the entry point.
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() || var.space == crate::AddressSpace::Private {
+ continue;
+ }
+ // the resolves have already been checked for `!fake_missing_bindings` case
+ let resolved = match var.space {
+ crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
+ crate::AddressSpace::WorkGroup => None,
+ crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => {
+ return Err(Error::UnsupportedAddressSpace(var.space))
+ }
+ _ => options
+ .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
+ .ok(),
+ };
+ if let Some(ref resolved) = resolved {
+ // Inline samplers are be defined in the EP body
+ if resolved.as_inline_sampler(options).is_some() {
+ continue;
+ }
+ }
+
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: resolved.as_ref(),
+ reference: true,
+ };
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{separator} ")?;
+ tyvar.try_fmt(&mut self.out)?;
+ if let Some(resolved) = resolved {
+ resolved.try_fmt(&mut self.out)?;
+ }
+ if let Some(value) = var.init {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, " = {coco}")?;
+ }
+ writeln!(self.out)?;
+ }
+
+ // If this entry uses any variable-length arrays, their sizes are
+ // passed as a final struct-typed argument.
+ if supports_array_length {
+ // this is checked earlier
+ let resolved = options.resolve_sizes_buffer(ep).unwrap();
+ let separator = if module.global_variables.is_empty() {
+ ' '
+ } else {
+ ','
+ };
+ write!(
+ self.out,
+ "{separator} constant _mslBufferSizes& _buffer_sizes",
+ )?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ // end of the entry point argument list
+ writeln!(self.out, ") {{")?;
+
+ if need_workgroup_variables_initialization {
+ self.write_workgroup_variables_initialization(
+ module,
+ mod_info,
+ fun_info,
+ local_invocation_id,
+ )?;
+ }
+
+ // Metal doesn't support private mutable variables outside of functions,
+ // so we put them here, just like the locals.
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() {
+ continue;
+ }
+ if var.space == crate::AddressSpace::Private {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: None,
+ reference: false,
+ };
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ match var.init {
+ Some(value) => {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ writeln!(self.out, " = {coco};")?;
+ }
+ None => {
+ writeln!(self.out, " = {{}};")?;
+ }
+ };
+ } else if let Some(ref binding) = var.binding {
+ // write an inline sampler
+ let resolved = options.resolve_resource_binding(ep, binding).unwrap();
+ if let Some(sampler) = resolved.as_inline_sampler(options) {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ writeln!(
+ self.out,
+ "{}constexpr {}::sampler {}(",
+ back::INDENT,
+ NAMESPACE,
+ name
+ )?;
+ self.put_inline_sampler_properties(back::Level(2), sampler)?;
+ writeln!(self.out, "{});", back::INDENT)?;
+ }
+ }
+ }
+
+ // Now take the arguments that we gathered into structs, and the
+ // structs that we flattened into arguments, and emit local
+ // variables with initializers that put everything back the way the
+ // body code expects.
+ //
+ // If we had to generate fresh names for struct members passed as
+ // arguments, be sure to use those names when rebuilding the struct.
+ //
+ // "Each day, I change some zeros to ones, and some ones to zeros.
+ // The rest, I leave alone."
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ let arg_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(arg.ty)];
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ back::INDENT,
+ struct_name,
+ arg_name
+ )?;
+ for (member_index, member) in members.iter().enumerate() {
+ let key = NameKey::StructMember(arg.ty, member_index as u32);
+ // If it's not in the varying struct, then we should
+ // have passed it as its own argument and assigned
+ // it a new name.
+ let name = match member.binding {
+ Some(crate::Binding::BuiltIn { .. }) => {
+ &flattened_member_names[&key]
+ }
+ _ => &self.names[&key],
+ };
+ if member_index != 0 {
+ write!(self.out, ", ")?;
+ }
+ if let Some(crate::Binding::Location { .. }) = member.binding {
+ write!(self.out, "{varyings_member_name}.")?;
+ }
+ write!(self.out, "{name}")?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ if let Some(crate::Binding::Location { .. }) = arg.binding {
+ writeln!(
+ self.out,
+ "{}const auto {} = {}.{};",
+ back::INDENT,
+ arg_name,
+ varyings_member_name,
+ arg_name
+ )?;
+ }
+ }
+ }
+ }
+
+ // Finally, declare all the local variables that we need
+ //TODO: we can postpone this till the relevant expressions are emitted
+ for (local_handle, local) in fun.local_variables.iter() {
+ let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)];
+ let ty_name = TypeContext {
+ handle: local.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ match local.init {
+ Some(value) => {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, " = {coco}")?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::EntryPoint(ep_index as _),
+ info: fun_info,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ pipeline_options,
+ },
+ mod_info,
+ result_struct: Some(&stage_out_name),
+ };
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ if ep_index + 1 != module.entry_points.len() {
+ writeln!(self.out)?;
+ }
+ }
+
+ Ok(info)
+ }
+
+ fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
+ // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
+ // so we try to avoid it here.
+ if flags.is_empty() {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
+ )?;
+ }
+ if flags.contains(crate::Barrier::STORAGE) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
+ )?;
+ }
+ if flags.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
+ )?;
+ }
+ Ok(())
+ }
+}
+
+/// Initializing workgroup variables is more tricky for Metal because we have to deal
+/// with atomics at the type-level (which don't have a copy constructor).
+mod workgroup_mem_init {
+ use crate::EntryPoint;
+
+ use super::*;
+
+ enum Access {
+ GlobalVariable(Handle<crate::GlobalVariable>),
+ StructMember(Handle<crate::Type>, u32),
+ Array(usize),
+ }
+
+ impl Access {
+ fn write<W: Write>(
+ &self,
+ writer: &mut W,
+ names: &FastHashMap<NameKey, String>,
+ ) -> Result<(), core::fmt::Error> {
+ match *self {
+ Access::GlobalVariable(handle) => {
+ write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
+ }
+ Access::StructMember(handle, index) => {
+ write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
+ }
+ Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
+ }
+ }
+ }
+
+ struct AccessStack {
+ stack: Vec<Access>,
+ array_depth: usize,
+ }
+
+ impl AccessStack {
+ const fn new() -> Self {
+ Self {
+ stack: Vec::new(),
+ array_depth: 0,
+ }
+ }
+
+ fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
+ let array_depth = self.array_depth;
+ self.stack.push(Access::Array(array_depth));
+ self.array_depth += 1;
+ let res = cb(self, array_depth);
+ self.stack.pop();
+ self.array_depth -= 1;
+ res
+ }
+
+ fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
+ self.stack.push(new);
+ let res = cb(self);
+ self.stack.pop();
+ res
+ }
+
+ fn write<W: Write>(
+ &self,
+ writer: &mut W,
+ names: &FastHashMap<NameKey, String>,
+ ) -> Result<(), core::fmt::Error> {
+ for next in self.stack.iter() {
+ next.write(writer, names)?;
+ }
+ Ok(())
+ }
+ }
+
+ impl<W: Write> Writer<W> {
+ pub(super) fn need_workgroup_variables_initialization(
+ &mut self,
+ options: &Options,
+ ep: &EntryPoint,
+ module: &crate::Module,
+ fun_info: &valid::FunctionInfo,
+ ) -> bool {
+ options.zero_initialize_workgroup_memory
+ && ep.stage == crate::ShaderStage::Compute
+ && module.global_variables.iter().any(|(handle, var)| {
+ !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ }
+
+ pub(super) fn write_workgroup_variables_initialization(
+ &mut self,
+ module: &crate::Module,
+ module_info: &valid::ModuleInfo,
+ fun_info: &valid::FunctionInfo,
+ local_invocation_id: Option<&NameKey>,
+ ) -> BackendResult {
+ let level = back::Level(1);
+
+ writeln!(
+ self.out,
+ "{}if ({}::all({} == {}::uint3(0u))) {{",
+ level,
+ NAMESPACE,
+ local_invocation_id
+ .map(|name_key| self.names[name_key].as_str())
+ .unwrap_or("__local_invocation_id"),
+ NAMESPACE,
+ )?;
+
+ let mut access_stack = AccessStack::new();
+
+ let vars = module.global_variables.iter().filter(|&(handle, var)| {
+ !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ });
+
+ for (handle, var) in vars {
+ access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ var.ty,
+ access_stack,
+ level.next(),
+ )
+ })?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)
+ }
+
+ fn write_workgroup_variable_initialization(
+ &mut self,
+ module: &crate::Module,
+ module_info: &valid::ModuleInfo,
+ ty: Handle<crate::Type>,
+ access_stack: &mut AccessStack,
+ level: back::Level,
+ ) -> BackendResult {
+ if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
+ write!(self.out, "{level}")?;
+ access_stack.write(&mut self.out, &self.names)?;
+ writeln!(self.out, " = {{}};")?;
+ } else {
+ match module.types[ty].inner {
+ crate::TypeInner::Atomic { .. } => {
+ write!(
+ self.out,
+ "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
+ )?;
+ access_stack.write(&mut self.out, &self.names)?;
+ writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
+ }
+ crate::TypeInner::Array { base, size, .. } => {
+ let count = match size.to_indexable_length(module).expect("Bad array size")
+ {
+ proc::IndexableLength::Known(count) => count,
+ proc::IndexableLength::Dynamic => unreachable!(),
+ };
+
+ access_stack.enter_array(|access_stack, array_depth| {
+ writeln!(
+ self.out,
+ "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
+ )?;
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ base,
+ access_stack,
+ level.next(),
+ )?;
+ writeln!(self.out, "{level}}}")?;
+ BackendResult::Ok(())
+ })?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (index, member) in members.iter().enumerate() {
+ access_stack.enter(
+ Access::StructMember(ty, index as u32),
+ |access_stack| {
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ member.ty,
+ access_stack,
+ level,
+ )
+ },
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ Ok(())
+ }
+ }
+}
+
+#[test]
+fn test_stack_size() {
+ use crate::valid::{Capabilities, ValidationFlags};
+ // create a module with at least one expression nested
+ let mut module = crate::Module::default();
+ let constant = module.constants.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(1.0),
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+ let mut fun = crate::Function::default();
+ let const_expr = fun
+ .expressions
+ .append(crate::Expression::Constant(constant), Default::default());
+ let nested_expr = fun.expressions.append(
+ crate::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr: const_expr,
+ },
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::Emit(fun.expressions.range_from(1)),
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::If {
+ condition: nested_expr,
+ accept: crate::Block::new(),
+ reject: crate::Block::new(),
+ },
+ Default::default(),
+ );
+ let _ = module.functions.append(fun, Default::default());
+ // analyse the module
+ let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty())
+ .validate(&module)
+ .unwrap();
+ // process the module
+ let mut writer = Writer::new(String::new());
+ writer
+ .write(&module, &info, &Default::default(), &Default::default())
+ .unwrap();
+
+ {
+ // check expression stack
+ let mut addresses_start = usize::MAX;
+ let mut addresses_end = 0usize;
+ for pointer in writer.put_expression_stack_pointers {
+ addresses_start = addresses_start.min(pointer as usize);
+ addresses_end = addresses_end.max(pointer as usize);
+ }
+ let stack_size = addresses_end - addresses_start;
+ // check the size (in debug only)
+ // last observed macOS value: 20528 (CI)
+ if !(11000..=25000).contains(&stack_size) {
+ panic!("`put_expression` stack size {stack_size} has changed!");
+ }
+ }
+
+ {
+ // check block stack
+ let mut addresses_start = usize::MAX;
+ let mut addresses_end = 0usize;
+ for pointer in writer.put_block_stack_pointers {
+ addresses_start = addresses_start.min(pointer as usize);
+ addresses_end = addresses_end.max(pointer as usize);
+ }
+ let stack_size = addresses_end - addresses_start;
+ // check the size (in debug only)
+ // last observed macOS value: 19152 (CI)
+ if !(9000..=20000).contains(&stack_size) {
+ panic!("`put_block` stack size {stack_size} has changed!");
+ }
+ }
+}