summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/back/hlsl
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/back/hlsl')
-rw-r--r--third_party/rust/naga/src/back/hlsl/conv.rs222
-rw-r--r--third_party/rust/naga/src/back/hlsl/help.rs1138
-rw-r--r--third_party/rust/naga/src/back/hlsl/keywords.rs904
-rw-r--r--third_party/rust/naga/src/back/hlsl/mod.rs302
-rw-r--r--third_party/rust/naga/src/back/hlsl/storage.rs494
-rw-r--r--third_party/rust/naga/src/back/hlsl/writer.rs3366
6 files changed, 6426 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs
new file mode 100644
index 0000000000..b6918ddc42
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/conv.rs
@@ -0,0 +1,222 @@
+use std::borrow::Cow;
+
+use crate::proc::Alignment;
+
+use super::Error;
+
+impl crate::ScalarKind {
+ pub(super) fn to_hlsl_cast(self) -> &'static str {
+ match self {
+ Self::Float => "asfloat",
+ Self::Sint => "asint",
+ Self::Uint => "asuint",
+ Self::Bool | Self::AbstractInt | Self::AbstractFloat => unreachable!(),
+ }
+ }
+}
+
+impl crate::Scalar {
+ /// Helper function that returns scalar related strings
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar>
+ pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> {
+ match self.kind {
+ crate::ScalarKind::Sint => Ok("int"),
+ crate::ScalarKind::Uint => Ok("uint"),
+ crate::ScalarKind::Float => match self.width {
+ 2 => Ok("half"),
+ 4 => Ok("float"),
+ 8 => Ok("double"),
+ _ => Err(Error::UnsupportedScalar(self)),
+ },
+ crate::ScalarKind::Bool => Ok("bool"),
+ crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => {
+ Err(Error::UnsupportedScalar(self))
+ }
+ }
+ }
+}
+
+impl crate::TypeInner {
+ pub(super) const fn is_matrix(&self) -> bool {
+ match *self {
+ Self::Matrix { .. } => true,
+ _ => false,
+ }
+ }
+
+ pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 {
+ match *self {
+ Self::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let stride = Alignment::from(rows) * scalar.width as u32;
+ let last_row_size = rows as u32 * scalar.width as u32;
+ ((columns as u32 - 1) * stride) + last_row_size
+ }
+ Self::Array { base, size, stride } => {
+ let count = match size {
+ crate::ArraySize::Constant(size) => size.get(),
+ // A dynamically-sized array has to have at least one element
+ crate::ArraySize::Dynamic => 1,
+ };
+ let last_el_size = gctx.types[base].inner.size_hlsl(gctx);
+ ((count - 1) * stride) + last_el_size
+ }
+ _ => self.size(gctx),
+ }
+ }
+
+ /// Used to generate the name of the wrapped type constructor
+ pub(super) fn hlsl_type_id<'a>(
+ base: crate::Handle<crate::Type>,
+ gctx: crate::proc::GlobalCtx,
+ names: &'a crate::FastHashMap<crate::proc::NameKey, String>,
+ ) -> Result<Cow<'a, str>, Error> {
+ Ok(match gctx.types[base].inner {
+ crate::TypeInner::Scalar(scalar) => Cow::Borrowed(scalar.to_hlsl_str()?),
+ crate::TypeInner::Vector { size, scalar } => Cow::Owned(format!(
+ "{}{}",
+ scalar.to_hlsl_str()?,
+ crate::back::vector_size_str(size)
+ )),
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => Cow::Owned(format!(
+ "{}{}x{}",
+ scalar.to_hlsl_str()?,
+ crate::back::vector_size_str(columns),
+ crate::back::vector_size_str(rows),
+ )),
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => Cow::Owned(format!(
+ "array{size}_{}_",
+ Self::hlsl_type_id(base, gctx, names)?
+ )),
+ crate::TypeInner::Struct { .. } => {
+ Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)])
+ }
+ _ => unreachable!(),
+ })
+ }
+}
+
+impl crate::StorageFormat {
+ pub(super) const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::R16Float => "float",
+ Self::R8Unorm | Self::R16Unorm => "unorm float",
+ Self::R8Snorm | Self::R16Snorm => "snorm float",
+ Self::R8Uint | Self::R16Uint => "uint",
+ Self::R8Sint | Self::R16Sint => "int",
+
+ Self::Rg16Float => "float2",
+ Self::Rg8Unorm | Self::Rg16Unorm => "unorm float2",
+ Self::Rg8Snorm | Self::Rg16Snorm => "snorm float2",
+
+ Self::Rg8Sint | Self::Rg16Sint => "int2",
+ Self::Rg8Uint | Self::Rg16Uint => "uint2",
+
+ Self::Rg11b10Float => "float3",
+
+ Self::Rgba16Float | Self::R32Float | Self::Rg32Float | Self::Rgba32Float => "float4",
+ Self::Rgba8Unorm | Self::Bgra8Unorm | Self::Rgba16Unorm | Self::Rgb10a2Unorm => {
+ "unorm float4"
+ }
+ Self::Rgba8Snorm | Self::Rgba16Snorm => "snorm float4",
+
+ Self::Rgba8Uint
+ | Self::Rgba16Uint
+ | Self::R32Uint
+ | Self::Rg32Uint
+ | Self::Rgba32Uint
+ | Self::Rgb10a2Uint => "uint4",
+ Self::Rgba8Sint
+ | Self::Rgba16Sint
+ | Self::R32Sint
+ | Self::Rg32Sint
+ | Self::Rgba32Sint => "int4",
+ }
+ }
+}
+
+impl crate::BuiltIn {
+ pub(super) fn to_hlsl_str(self) -> Result<&'static str, Error> {
+ Ok(match self {
+ Self::Position { .. } => "SV_Position",
+ // vertex
+ Self::ClipDistance => "SV_ClipDistance",
+ Self::CullDistance => "SV_CullDistance",
+ Self::InstanceIndex => "SV_InstanceID",
+ Self::VertexIndex => "SV_VertexID",
+ // fragment
+ Self::FragDepth => "SV_Depth",
+ Self::FrontFacing => "SV_IsFrontFace",
+ Self::PrimitiveIndex => "SV_PrimitiveID",
+ Self::SampleIndex => "SV_SampleIndex",
+ Self::SampleMask => "SV_Coverage",
+ // compute
+ Self::GlobalInvocationId => "SV_DispatchThreadID",
+ Self::LocalInvocationId => "SV_GroupThreadID",
+ Self::LocalInvocationIndex => "SV_GroupIndex",
+ Self::WorkGroupId => "SV_GroupID",
+ // The specific semantic we use here doesn't matter, because references
+ // to this field will get replaced with references to `SPECIAL_CBUF_VAR`
+ // in `Writer::write_expr`.
+ Self::NumWorkGroups => "SV_GroupID",
+ Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
+ return Err(Error::Unimplemented(format!("builtin {self:?}")))
+ }
+ Self::PointSize | Self::ViewIndex | Self::PointCoord => {
+ return Err(Error::Custom(format!("Unsupported builtin {self:?}")))
+ }
+ })
+ }
+}
+
+impl crate::Interpolation {
+ /// Return the string corresponding to the HLSL interpolation qualifier.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ // Would be "linear", but it's the default interpolation in SM4 and up
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-struct#interpolation-modifiers-introduced-in-shader-model-4
+ Self::Perspective => None,
+ Self::Linear => Some("noperspective"),
+ Self::Flat => Some("nointerpolation"),
+ }
+ }
+}
+
+impl crate::Sampling {
+ /// Return the HLSL auxiliary qualifier for the given sampling value.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ Self::Center => None,
+ Self::Centroid => Some("centroid"),
+ Self::Sample => Some("sample"),
+ }
+ }
+}
+
+impl crate::AtomicFunction {
+ /// Return the HLSL suffix for the `InterlockedXxx` method.
+ pub(super) const fn to_hlsl_suffix(self) -> &'static str {
+ match self {
+ Self::Add | Self::Subtract => "Add",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { .. } => "", //TODO
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs
new file mode 100644
index 0000000000..fa6062a1ad
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/help.rs
@@ -0,0 +1,1138 @@
+/*!
+Helpers for the hlsl backend
+
+Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and hlsl backend:
+
+Due to implementation of `GetDimensions` function in hlsl (<https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>)
+backend can't work with it as an expression.
+Instead, it generates a unique wrapped function per `Expression::ImageQuery`, based on texture info and query function.
+See `WrappedImageQuery` struct that represents a unique function and will be generated before writing all statements and expressions.
+This allowed to works with `Expression::ImageQuery` as expression and write wrapped function.
+
+For example:
+```wgsl
+let dim_1d = textureDimensions(image_1d);
+```
+
+```hlsl
+int NagaDimensions1D(Texture1D<float4>)
+{
+ uint4 ret;
+ image_1d.GetDimensions(ret.x);
+ return ret.x;
+}
+
+int dim_1d = NagaDimensions1D(image_1d);
+```
+*/
+
+use super::{super::FunctionCtx, BackendResult};
+use crate::{arena::Handle, proc::NameKey};
+use std::fmt::Write;
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedArrayLength {
+ pub(super) writable: bool,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedImageQuery {
+ pub(super) dim: crate::ImageDimension,
+ pub(super) arrayed: bool,
+ pub(super) class: crate::ImageClass,
+ pub(super) query: ImageQuery,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedConstructor {
+ pub(super) ty: Handle<crate::Type>,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedStructMatrixAccess {
+ pub(super) ty: Handle<crate::Type>,
+ pub(super) index: u32,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedMatCx2 {
+ pub(super) columns: crate::VectorSize,
+}
+
+/// HLSL backend requires its own `ImageQuery` enum.
+///
+/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
+/// IR version can't be unique per function, because it's store mipmap level as an expression.
+///
+/// For example:
+/// ```wgsl
+/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1);
+/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1);
+/// ```
+///
+/// ```ir
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [1],
+/// ),
+/// },
+/// },
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [2],
+/// ),
+/// },
+/// },
+/// ```
+///
+/// HLSL should generate only 1 function for this case.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) enum ImageQuery {
+ Size,
+ SizeLevel,
+ NumLevels,
+ NumLayers,
+ NumSamples,
+}
+
+impl From<crate::ImageQuery> for ImageQuery {
+ fn from(q: crate::ImageQuery) -> Self {
+ use crate::ImageQuery as Iq;
+ match q {
+ Iq::Size { level: Some(_) } => ImageQuery::SizeLevel,
+ Iq::Size { level: None } => ImageQuery::Size,
+ Iq::NumLevels => ImageQuery::NumLevels,
+ Iq::NumLayers => ImageQuery::NumLayers,
+ Iq::NumSamples => ImageQuery::NumSamples,
+ }
+ }
+}
+
+impl<'a, W: Write> super::Writer<'a, W> {
+ pub(super) fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ ) -> BackendResult {
+ let access_str = match class {
+ crate::ImageClass::Storage { .. } => "RW",
+ _ => "",
+ };
+ let dim_str = dim.to_hlsl_str();
+ let arrayed_str = if arrayed { "Array" } else { "" };
+ write!(self.out, "{access_str}Texture{dim_str}{arrayed_str}")?;
+ match class {
+ crate::ImageClass::Depth { multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ write!(self.out, "{multi_str}<float>")?
+ }
+ crate::ImageClass::Sampled { kind, multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ let scalar_kind_str = crate::Scalar { kind, width: 4 }.to_hlsl_str()?;
+ write!(self.out, "{multi_str}<{scalar_kind_str}4>")?
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let storage_format_str = format.to_hlsl_str();
+ write!(self.out, "<{storage_format_str}>")?
+ }
+ }
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_array_length_function_name(
+ &mut self,
+ query: WrappedArrayLength,
+ ) -> BackendResult {
+ let access_str = if query.writable { "RW" } else { "" };
+ write!(self.out, "NagaBufferLength{access_str}",)?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ArrayLength`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer-getdimensions>
+ pub(super) fn write_wrapped_array_length_function(
+ &mut self,
+ wal: WrappedArrayLength,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "buffer";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ write!(self.out, "uint ")?;
+ self.write_wrapped_array_length_function_name(wal)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let access_str = if wal.writable { "RW" } else { "" };
+ writeln!(
+ self.out,
+ "{access_str}ByteAddressBuffer {ARGUMENT_VARIABLE_NAME})"
+ )?;
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{INDENT}uint {RETURN_VARIABLE_NAME};")?;
+ writeln!(
+ self.out,
+ "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions({RETURN_VARIABLE_NAME});"
+ )?;
+
+ // Write return value
+ writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_image_query_function_name(
+ &mut self,
+ query: WrappedImageQuery,
+ ) -> BackendResult {
+ let dim_str = query.dim.to_hlsl_str();
+ let class_str = match query.class {
+ crate::ImageClass::Sampled { multi: true, .. } => "MS",
+ crate::ImageClass::Depth { multi: true } => "DepthMS",
+ crate::ImageClass::Depth { multi: false } => "Depth",
+ crate::ImageClass::Sampled { multi: false, .. } => "",
+ crate::ImageClass::Storage { .. } => "RW",
+ };
+ let arrayed_str = if query.arrayed { "Array" } else { "" };
+ let query_str = match query.query {
+ ImageQuery::Size => "Dimensions",
+ ImageQuery::SizeLevel => "MipDimensions",
+ ImageQuery::NumLevels => "NumLevels",
+ ImageQuery::NumLayers => "NumLayers",
+ ImageQuery::NumSamples => "NumSamples",
+ };
+
+ write!(self.out, "Naga{class_str}{query_str}{dim_str}{arrayed_str}")?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ImageQuery`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>
+ pub(super) fn write_wrapped_image_query_function(
+ &mut self,
+ module: &crate::Module,
+ wiq: WrappedImageQuery,
+ expr_handle: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ use crate::{
+ back::{COMPONENTS, INDENT},
+ ImageDimension as IDim,
+ };
+
+ const ARGUMENT_VARIABLE_NAME: &str = "tex";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+ const MIP_LEVEL_PARAM: &str = "mip_level";
+
+ // Write function return type and name
+ let ret_ty = func_ctx.resolve_type(expr_handle, &module.types);
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_image_query_function_name(wiq)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ // Texture always first parameter
+ self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?;
+ write!(self.out, " {ARGUMENT_VARIABLE_NAME}")?;
+ // Mipmap is a second parameter if exists
+ if let ImageQuery::SizeLevel = wiq.query {
+ write!(self.out, ", uint {MIP_LEVEL_PARAM}")?;
+ }
+ writeln!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ let array_coords = usize::from(wiq.arrayed);
+ // extra parameter is the mip level count or the sample count
+ let extra_coords = match wiq.class {
+ crate::ImageClass::Storage { .. } => 0,
+ crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } => 1,
+ };
+
+ // GetDimensions Overloaded Methods
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods
+ let (ret_swizzle, number_of_params) = match wiq.query {
+ ImageQuery::Size | ImageQuery::SizeLevel => {
+ let ret = match wiq.dim {
+ IDim::D1 => "x",
+ IDim::D2 => "xy",
+ IDim::D3 => "xyz",
+ IDim::Cube => "xy",
+ };
+ (ret, ret.len() + array_coords + extra_coords)
+ }
+ ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => {
+ if wiq.arrayed || wiq.dim == IDim::D3 {
+ ("w", 4)
+ } else {
+ ("z", 3)
+ }
+ }
+ };
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{INDENT}uint4 {RETURN_VARIABLE_NAME};")?;
+ write!(self.out, "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions(")?;
+ match wiq.query {
+ ImageQuery::SizeLevel => {
+ write!(self.out, "{MIP_LEVEL_PARAM}, ")?;
+ }
+ _ => match wiq.class {
+ crate::ImageClass::Sampled { multi: true, .. }
+ | crate::ImageClass::Depth { multi: true }
+ | crate::ImageClass::Storage { .. } => {}
+ _ => {
+ // Write zero mipmap level for supported types
+ write!(self.out, "0, ")?;
+ }
+ },
+ }
+
+ for component in COMPONENTS[..number_of_params - 1].iter() {
+ write!(self.out, "{RETURN_VARIABLE_NAME}.{component}, ")?;
+ }
+
+ // write last parameter without comma and space for last parameter
+ write!(
+ self.out,
+ "{}.{}",
+ RETURN_VARIABLE_NAME,
+ COMPONENTS[number_of_params - 1]
+ )?;
+
+ writeln!(self.out, ");")?;
+
+ // Write return value
+ writeln!(
+ self.out,
+ "{INDENT}return {RETURN_VARIABLE_NAME}.{ret_swizzle};"
+ )?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_constructor_function_name(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?;
+ write!(self.out, "Construct{name}")?;
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::Compose` for structures.
+ pub(super) fn write_wrapped_constructor_function(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "arg";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner {
+ write!(self.out, "typedef ")?;
+ self.write_type(module, constructor.ty)?;
+ write!(self.out, " ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ self.write_array_size(module, base, size)?;
+ writeln!(self.out, ";")?;
+
+ write!(self.out, "ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ } else {
+ self.write_type(module, constructor.ty)?;
+ }
+ write!(self.out, " ")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+
+ let mut write_arg = |i, ty| -> BackendResult {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, ty)?;
+ write!(self.out, " {ARGUMENT_VARIABLE_NAME}{i}")?;
+ if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ Ok(())
+ };
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (i, member) in members.iter().enumerate() {
+ write_arg(i, member.ty)?;
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ for i in 0..size.get() as usize {
+ write_arg(i, base)?;
+ }
+ }
+ _ => unreachable!(),
+ };
+
+ write!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, " {{")?;
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(constructor.ty)];
+ writeln!(
+ self.out,
+ "{INDENT}{struct_name} {RETURN_VARIABLE_NAME} = ({struct_name})0;"
+ )?;
+ for (i, member) in members.iter().enumerate() {
+ let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ columns,
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ for j in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}{RETURN_VARIABLE_NAME}.{field_name}_{j} = {ARGUMENT_VARIABLE_NAME}{i}[{j}];"
+ )?;
+ }
+ }
+ ref other => {
+ // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s
+ // (where the inner matrix is represented by a struct with C `float2` members).
+ // See the module-level block comment in mod.rs for details.
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ write!(
+ self.out,
+ "{}{}.{} = (__mat{}x2",
+ INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8
+ )?;
+ if let crate::TypeInner::Array { base, size, .. } = *other {
+ self.write_array_size(module, base, size)?;
+ }
+ writeln!(self.out, "){ARGUMENT_VARIABLE_NAME}{i};",)?;
+ } else {
+ writeln!(
+ self.out,
+ "{INDENT}{RETURN_VARIABLE_NAME}.{field_name} = {ARGUMENT_VARIABLE_NAME}{i};",
+ )?;
+ }
+ }
+ }
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ write!(self.out, "{INDENT}")?;
+ self.write_type(module, base)?;
+ write!(self.out, " {RETURN_VARIABLE_NAME}")?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
+ write!(self.out, " = {{ ")?;
+ for i in 0..size.get() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{ARGUMENT_VARIABLE_NAME}{i}")?;
+ }
+ writeln!(self.out, " }};",)?;
+ }
+ _ => unreachable!(),
+ }
+
+ // Write return value
+ writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_get_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "GetMat{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to get a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_get_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+
+ // Write function return type and name
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let ret_ty = &module.types[member.ty].inner;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_struct_matrix_get_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}")?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ // Write return value
+ write!(self.out, "{INDENT}return ")?;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, "(")?;
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+ writeln!(self.out, ");")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMat{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ self.write_type(module, member.ty)?;
+ write!(self.out, " {MATRIX_ARGUMENT_VARIABLE_NAME}")?;
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {MATRIX_ARGUMENT_VARIABLE_NAME}[{i}];"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatVec{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a vec2 on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let vec_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { rows, scalar, .. } => {
+ crate::TypeInner::Vector { size: rows, scalar }
+ }
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &vec_ty)?;
+ write!(
+ self.out,
+ " {VECTOR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}"
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {VECTOR_ARGUMENT_VARIABLE_NAME}; break; }}"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{INDENT}}}")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatScalar{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a float on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+ const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_scalar_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let scalar_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar),
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &scalar_ty)?;
+ write!(
+ self.out,
+ " {SCALAR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}, uint {VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}"
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}[{VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}] = {SCALAR_ARGUMENT_VARIABLE_NAME}; break; }}"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{INDENT}}}")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ /// Write functions to create special types.
+ pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
+ for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
+ match type_key {
+ &crate::PredeclaredType::ModfResult { size, width }
+ | &crate::PredeclaredType::FrexpResult { size, width } => {
+ let arg_type_name_owner;
+ let arg_type_name = if let Some(size) = size {
+ arg_type_name_owner = format!(
+ "{}{}",
+ if width == 8 { "double" } else { "float" },
+ size as u8
+ );
+ &arg_type_name_owner
+ } else if width == 8 {
+ "double"
+ } else {
+ "float"
+ };
+
+ let (defined_func_name, called_func_name, second_field_name, sign_multiplier) =
+ if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
+ (super::writer::MODF_FUNCTION, "modf", "whole", "")
+ } else {
+ (
+ super::writer::FREXP_FUNCTION,
+ "frexp",
+ "exp_",
+ "sign(arg) * ",
+ )
+ };
+
+ let struct_name = &self.names[&NameKey::Type(*struct_ty)];
+
+ writeln!(
+ self.out,
+ "{struct_name} {defined_func_name}({arg_type_name} arg) {{
+ {arg_type_name} other;
+ {struct_name} result;
+ result.fract = {sign_multiplier}{called_func_name}(arg, other);
+ result.{second_field_name} = other;
+ return result;
+}}"
+ )?;
+ writeln!(self.out)?;
+ }
+ &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper function that writes compose wrapped functions
+ pub(super) fn write_wrapped_compose_functions(
+ &mut self,
+ module: &crate::Module,
+ expressions: &crate::Arena<crate::Expression>,
+ ) -> BackendResult {
+ for (handle, _) in expressions.iter() {
+ if let crate::Expression::Compose { ty, .. } = expressions[handle] {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
+ let constructor = WrappedConstructor { ty };
+ if self.wrapped.constructors.insert(constructor) {
+ self.write_wrapped_constructor_function(module, constructor)?;
+ }
+ }
+ _ => {}
+ };
+ }
+ }
+ Ok(())
+ }
+
+ /// Helper function that writes various wrapped functions
+ pub(super) fn write_wrapped_functions(
+ &mut self,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ self.write_wrapped_compose_functions(module, func_ctx.expressions)?;
+
+ for (handle, _) in func_ctx.expressions.iter() {
+ match func_ctx.expressions[handle] {
+ crate::Expression::ArrayLength(expr) => {
+ let global_expr = match func_ctx.expressions[expr] {
+ crate::Expression::GlobalVariable(_) => expr,
+ crate::Expression::AccessIndex { base, index: _ } => base,
+ ref other => unreachable!("Array length of {:?}", other),
+ };
+ let global_var = match func_ctx.expressions[global_expr] {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &module.global_variables[var_handle]
+ }
+ ref other => unreachable!("Array length of base {:?}", other),
+ };
+ let storage_access = match global_var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wal = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ if self.wrapped.array_lengths.insert(wal) {
+ self.write_wrapped_array_length_function(wal)?;
+ }
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ let wiq = match *func_ctx.resolve_type(image, &module.types) {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ },
+ _ => unreachable!("we only query images"),
+ };
+
+ if self.wrapped.image_queries.insert(wiq) {
+ self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?;
+ }
+ }
+ // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage`
+ // since they will later be used by the fn `write_storage_load`
+ crate::Expression::Load { pointer } => {
+ let pointer_space = func_ctx
+ .resolve_type(pointer, &module.types)
+ .pointer_space();
+
+ if let Some(crate::AddressSpace::Storage { .. }) = pointer_space {
+ if let Some(ty) = func_ctx.info[handle].ty.handle() {
+ write_wrapped_constructor(self, ty, module)?;
+ }
+ }
+
+ fn write_wrapped_constructor<W: Write>(
+ writer: &mut super::Writer<'_, W>,
+ ty: Handle<crate::Type>,
+ module: &crate::Module,
+ ) -> BackendResult {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for member in members {
+ write_wrapped_constructor(writer, member.ty, module)?;
+ }
+
+ let constructor = WrappedConstructor { ty };
+ if writer.wrapped.constructors.insert(constructor) {
+ writer
+ .write_wrapped_constructor_function(module, constructor)?;
+ }
+ }
+ crate::TypeInner::Array { base, .. } => {
+ write_wrapped_constructor(writer, base, module)?;
+
+ let constructor = WrappedConstructor { ty };
+ if writer.wrapped.constructors.insert(constructor) {
+ writer
+ .write_wrapped_constructor_function(module, constructor)?;
+ }
+ }
+ _ => {}
+ };
+
+ Ok(())
+ }
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s
+ // (see top level module docs for details).
+ //
+ // The functions injected here are required to get the matrix accesses working.
+ crate::Expression::AccessIndex { base, index } => {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ crate::TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+ if let crate::TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ let access = WrappedStructMatrixAccess { ty, index };
+
+ if self.wrapped.struct_matrix_access.insert(access) {
+ self.write_wrapped_struct_matrix_get_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_vec_function(
+ module, access,
+ )?;
+ self.write_wrapped_struct_matrix_set_scalar_function(
+ module, access,
+ )?;
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ };
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn write_texture_coordinates(
+ &mut self,
+ kind: &str,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ mip_level: Option<Handle<crate::Expression>>,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ // HLSL expects the array index to be merged with the coordinate
+ let extra = array_index.is_some() as usize + (mip_level.is_some()) as usize;
+ if extra == 0 {
+ self.write_expr(module, coordinate, func_ctx)?;
+ } else {
+ let num_coords = match *func_ctx.resolve_type(coordinate, &module.types) {
+ crate::TypeInner::Scalar { .. } => 1,
+ crate::TypeInner::Vector { size, .. } => size as usize,
+ _ => unreachable!(),
+ };
+ write!(self.out, "{}{}(", kind, num_coords + extra)?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ if let Some(expr) = mip_level {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Ok(())
+ }
+
+ pub(super) fn write_mat_cx2_typedef_and_functions(
+ &mut self,
+ WrappedMatCx2 { columns }: WrappedMatCx2,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ // typedef
+ write!(self.out, "typedef struct {{ ")?;
+ for i in 0..columns as u8 {
+ write!(self.out, "float2 _{i}; ")?;
+ }
+ writeln!(self.out, "}} __mat{}x2;", columns as u8)?;
+
+ // __get_col_of_mat
+ writeln!(
+ self.out,
+ "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(self.out, "{INDENT}case {i}: {{ return mat._{i}; }}")?;
+ }
+ writeln!(self.out, "{INDENT}default: {{ return (float2)0; }}")?;
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ // __set_col_of_mat
+ writeln!(
+ self.out,
+ "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(self.out, "{INDENT}case {i}: {{ mat._{i} = value; break; }}")?;
+ }
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ // __set_el_of_mat
+ writeln!(
+ self.out,
+ "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ mat._{i}[vec_idx] = value; break; }}"
+ )?;
+ }
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_all_mat_cx2_typedefs_and_functions(
+ &mut self,
+ module: &crate::Module,
+ ) -> BackendResult {
+ for (handle, _) in module.global_variables.iter() {
+ let global = &module.global_variables[handle];
+
+ if global.space == crate::AddressSpace::Uniform {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, global.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if self.wrapped.mat_cx2s.insert(entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ }
+ }
+ }
+ }
+
+ for (_, ty) in module.types.iter() {
+ if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
+ for member in members.iter() {
+ if let crate::TypeInner::Array { .. } = module.types[member.ty].inner {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if self.wrapped.mat_cx2s.insert(entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/keywords.rs b/third_party/rust/naga/src/back/hlsl/keywords.rs
new file mode 100644
index 0000000000..059e533ff7
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/keywords.rs
@@ -0,0 +1,904 @@
+// When compiling with FXC without strict mode, these keywords are actually case insensitive.
+// If you compile with strict mode and specify a different casing like "Pass" instead in an identifier, FXC will give this error:
+// "error X3086: alternate cases for 'pass' are deprecated in strict mode"
+// This behavior is not documented anywhere, but as far as I can tell this is the full list.
+pub const RESERVED_CASE_INSENSITIVE: &[&str] = &[
+ "asm",
+ "decl",
+ "pass",
+ "technique",
+ "Texture1D",
+ "Texture2D",
+ "Texture3D",
+ "TextureCube",
+];
+
+pub const RESERVED: &[&str] = &[
+ // FXC keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-keywords.md?plain=1#L99-L118
+ "AppendStructuredBuffer",
+ "asm",
+ "asm_fragment",
+ "BlendState",
+ "bool",
+ "break",
+ "Buffer",
+ "ByteAddressBuffer",
+ "case",
+ "cbuffer",
+ "centroid",
+ "class",
+ "column_major",
+ "compile",
+ "compile_fragment",
+ "CompileShader",
+ "const",
+ "continue",
+ "ComputeShader",
+ "ConsumeStructuredBuffer",
+ "default",
+ "DepthStencilState",
+ "DepthStencilView",
+ "discard",
+ "do",
+ "double",
+ "DomainShader",
+ "dword",
+ "else",
+ "export",
+ "extern",
+ "false",
+ "float",
+ "for",
+ "fxgroup",
+ "GeometryShader",
+ "groupshared",
+ "half",
+ "Hullshader",
+ "if",
+ "in",
+ "inline",
+ "inout",
+ "InputPatch",
+ "int",
+ "interface",
+ "line",
+ "lineadj",
+ "linear",
+ "LineStream",
+ "matrix",
+ "min16float",
+ "min10float",
+ "min16int",
+ "min12int",
+ "min16uint",
+ "namespace",
+ "nointerpolation",
+ "noperspective",
+ "NULL",
+ "out",
+ "OutputPatch",
+ "packoffset",
+ "pass",
+ "pixelfragment",
+ "PixelShader",
+ "point",
+ "PointStream",
+ "precise",
+ "RasterizerState",
+ "RenderTargetView",
+ "return",
+ "register",
+ "row_major",
+ "RWBuffer",
+ "RWByteAddressBuffer",
+ "RWStructuredBuffer",
+ "RWTexture1D",
+ "RWTexture1DArray",
+ "RWTexture2D",
+ "RWTexture2DArray",
+ "RWTexture3D",
+ "sample",
+ "sampler",
+ "SamplerState",
+ "SamplerComparisonState",
+ "shared",
+ "snorm",
+ "stateblock",
+ "stateblock_state",
+ "static",
+ "string",
+ "struct",
+ "switch",
+ "StructuredBuffer",
+ "tbuffer",
+ "technique",
+ "technique10",
+ "technique11",
+ "texture",
+ "Texture1D",
+ "Texture1DArray",
+ "Texture2D",
+ "Texture2DArray",
+ "Texture2DMS",
+ "Texture2DMSArray",
+ "Texture3D",
+ "TextureCube",
+ "TextureCubeArray",
+ "true",
+ "typedef",
+ "triangle",
+ "triangleadj",
+ "TriangleStream",
+ "uint",
+ "uniform",
+ "unorm",
+ "unsigned",
+ "vector",
+ "vertexfragment",
+ "VertexShader",
+ "void",
+ "volatile",
+ "while",
+ // FXC reserved keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-reserved-words.md?plain=1#L19-L38
+ "auto",
+ "case",
+ "catch",
+ "char",
+ "class",
+ "const_cast",
+ "default",
+ "delete",
+ "dynamic_cast",
+ "enum",
+ "explicit",
+ "friend",
+ "goto",
+ "long",
+ "mutable",
+ "new",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "reinterpret_cast",
+ "short",
+ "signed",
+ "sizeof",
+ "static_cast",
+ "template",
+ "this",
+ "throw",
+ "try",
+ "typename",
+ "union",
+ "unsigned",
+ "using",
+ "virtual",
+ // FXC intrinsics, from https://github.com/MicrosoftDocs/win32/blob/1682b99e203708f6f5eda972d966e30f3c1588de/desktop-src/direct3dhlsl/dx-graphics-hlsl-intrinsic-functions.md?plain=1#L26-L165
+ "abort",
+ "abs",
+ "acos",
+ "all",
+ "AllMemoryBarrier",
+ "AllMemoryBarrierWithGroupSync",
+ "any",
+ "asdouble",
+ "asfloat",
+ "asin",
+ "asint",
+ "asuint",
+ "atan",
+ "atan2",
+ "ceil",
+ "CheckAccessFullyMapped",
+ "clamp",
+ "clip",
+ "cos",
+ "cosh",
+ "countbits",
+ "cross",
+ "D3DCOLORtoUBYTE4",
+ "ddx",
+ "ddx_coarse",
+ "ddx_fine",
+ "ddy",
+ "ddy_coarse",
+ "ddy_fine",
+ "degrees",
+ "determinant",
+ "DeviceMemoryBarrier",
+ "DeviceMemoryBarrierWithGroupSync",
+ "distance",
+ "dot",
+ "dst",
+ "errorf",
+ "EvaluateAttributeCentroid",
+ "EvaluateAttributeAtSample",
+ "EvaluateAttributeSnapped",
+ "exp",
+ "exp2",
+ "f16tof32",
+ "f32tof16",
+ "faceforward",
+ "firstbithigh",
+ "firstbitlow",
+ "floor",
+ "fma",
+ "fmod",
+ "frac",
+ "frexp",
+ "fwidth",
+ "GetRenderTargetSampleCount",
+ "GetRenderTargetSamplePosition",
+ "GroupMemoryBarrier",
+ "GroupMemoryBarrierWithGroupSync",
+ "InterlockedAdd",
+ "InterlockedAnd",
+ "InterlockedCompareExchange",
+ "InterlockedCompareStore",
+ "InterlockedExchange",
+ "InterlockedMax",
+ "InterlockedMin",
+ "InterlockedOr",
+ "InterlockedXor",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "ldexp",
+ "length",
+ "lerp",
+ "lit",
+ "log",
+ "log10",
+ "log2",
+ "mad",
+ "max",
+ "min",
+ "modf",
+ "msad4",
+ "mul",
+ "noise",
+ "normalize",
+ "pow",
+ "printf",
+ "Process2DQuadTessFactorsAvg",
+ "Process2DQuadTessFactorsMax",
+ "Process2DQuadTessFactorsMin",
+ "ProcessIsolineTessFactors",
+ "ProcessQuadTessFactorsAvg",
+ "ProcessQuadTessFactorsMax",
+ "ProcessQuadTessFactorsMin",
+ "ProcessTriTessFactorsAvg",
+ "ProcessTriTessFactorsMax",
+ "ProcessTriTessFactorsMin",
+ "radians",
+ "rcp",
+ "reflect",
+ "refract",
+ "reversebits",
+ "round",
+ "rsqrt",
+ "saturate",
+ "sign",
+ "sin",
+ "sincos",
+ "sinh",
+ "smoothstep",
+ "sqrt",
+ "step",
+ "tan",
+ "tanh",
+ "tex1D",
+ "tex1Dbias",
+ "tex1Dgrad",
+ "tex1Dlod",
+ "tex1Dproj",
+ "tex2D",
+ "tex2Dbias",
+ "tex2Dgrad",
+ "tex2Dlod",
+ "tex2Dproj",
+ "tex3D",
+ "tex3Dbias",
+ "tex3Dgrad",
+ "tex3Dlod",
+ "tex3Dproj",
+ "texCUBE",
+ "texCUBEbias",
+ "texCUBEgrad",
+ "texCUBElod",
+ "texCUBEproj",
+ "transpose",
+ "trunc",
+ // DXC (reserved) keywords, from https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/include/clang/Basic/TokenKinds.def#L222-L648
+ // with the KEYALL, KEYCXX, BOOLSUPPORT, WCHARSUPPORT, KEYHLSL options enabled (see https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/lib/Frontend/CompilerInvocation.cpp#L1199)
+ "auto",
+ "break",
+ "case",
+ "char",
+ "const",
+ "continue",
+ "default",
+ "do",
+ "double",
+ "else",
+ "enum",
+ "extern",
+ "float",
+ "for",
+ "goto",
+ "if",
+ "inline",
+ "int",
+ "long",
+ "register",
+ "return",
+ "short",
+ "signed",
+ "sizeof",
+ "static",
+ "struct",
+ "switch",
+ "typedef",
+ "union",
+ "unsigned",
+ "void",
+ "volatile",
+ "while",
+ "_Alignas",
+ "_Alignof",
+ "_Atomic",
+ "_Complex",
+ "_Generic",
+ "_Imaginary",
+ "_Noreturn",
+ "_Static_assert",
+ "_Thread_local",
+ "__func__",
+ "__objc_yes",
+ "__objc_no",
+ "asm",
+ "bool",
+ "catch",
+ "class",
+ "const_cast",
+ "delete",
+ "dynamic_cast",
+ "explicit",
+ "export",
+ "false",
+ "friend",
+ "mutable",
+ "namespace",
+ "new",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "reinterpret_cast",
+ "static_cast",
+ "template",
+ "this",
+ "throw",
+ "true",
+ "try",
+ "typename",
+ "typeid",
+ "using",
+ "virtual",
+ "wchar_t",
+ "_Decimal32",
+ "_Decimal64",
+ "_Decimal128",
+ "__null",
+ "__alignof",
+ "__attribute",
+ "__builtin_choose_expr",
+ "__builtin_offsetof",
+ "__builtin_va_arg",
+ "__extension__",
+ "__imag",
+ "__int128",
+ "__label__",
+ "__real",
+ "__thread",
+ "__FUNCTION__",
+ "__PRETTY_FUNCTION__",
+ "__is_nothrow_assignable",
+ "__is_constructible",
+ "__is_nothrow_constructible",
+ "__has_nothrow_assign",
+ "__has_nothrow_move_assign",
+ "__has_nothrow_copy",
+ "__has_nothrow_constructor",
+ "__has_trivial_assign",
+ "__has_trivial_move_assign",
+ "__has_trivial_copy",
+ "__has_trivial_constructor",
+ "__has_trivial_move_constructor",
+ "__has_trivial_destructor",
+ "__has_virtual_destructor",
+ "__is_abstract",
+ "__is_base_of",
+ "__is_class",
+ "__is_convertible_to",
+ "__is_empty",
+ "__is_enum",
+ "__is_final",
+ "__is_literal",
+ "__is_literal_type",
+ "__is_pod",
+ "__is_polymorphic",
+ "__is_trivial",
+ "__is_union",
+ "__is_trivially_constructible",
+ "__is_trivially_copyable",
+ "__is_trivially_assignable",
+ "__underlying_type",
+ "__is_lvalue_expr",
+ "__is_rvalue_expr",
+ "__is_arithmetic",
+ "__is_floating_point",
+ "__is_integral",
+ "__is_complete_type",
+ "__is_void",
+ "__is_array",
+ "__is_function",
+ "__is_reference",
+ "__is_lvalue_reference",
+ "__is_rvalue_reference",
+ "__is_fundamental",
+ "__is_object",
+ "__is_scalar",
+ "__is_compound",
+ "__is_pointer",
+ "__is_member_object_pointer",
+ "__is_member_function_pointer",
+ "__is_member_pointer",
+ "__is_const",
+ "__is_volatile",
+ "__is_standard_layout",
+ "__is_signed",
+ "__is_unsigned",
+ "__is_same",
+ "__is_convertible",
+ "__array_rank",
+ "__array_extent",
+ "__private_extern__",
+ "__module_private__",
+ "__declspec",
+ "__cdecl",
+ "__stdcall",
+ "__fastcall",
+ "__thiscall",
+ "__vectorcall",
+ "cbuffer",
+ "tbuffer",
+ "packoffset",
+ "linear",
+ "centroid",
+ "nointerpolation",
+ "noperspective",
+ "sample",
+ "column_major",
+ "row_major",
+ "in",
+ "out",
+ "inout",
+ "uniform",
+ "precise",
+ "center",
+ "shared",
+ "groupshared",
+ "discard",
+ "snorm",
+ "unorm",
+ "point",
+ "line",
+ "lineadj",
+ "triangle",
+ "triangleadj",
+ "globallycoherent",
+ "interface",
+ "sampler_state",
+ "technique",
+ "indices",
+ "vertices",
+ "primitives",
+ "payload",
+ "Technique",
+ "technique10",
+ "technique11",
+ "__builtin_omp_required_simd_align",
+ "__pascal",
+ "__fp16",
+ "__alignof__",
+ "__asm",
+ "__asm__",
+ "__attribute__",
+ "__complex",
+ "__complex__",
+ "__const",
+ "__const__",
+ "__decltype",
+ "__imag__",
+ "__inline",
+ "__inline__",
+ "__nullptr",
+ "__real__",
+ "__restrict",
+ "__restrict__",
+ "__signed",
+ "__signed__",
+ "__typeof",
+ "__typeof__",
+ "__volatile",
+ "__volatile__",
+ "_Nonnull",
+ "_Nullable",
+ "_Null_unspecified",
+ "__builtin_convertvector",
+ "__char16_t",
+ "__char32_t",
+ // DXC intrinsics, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/utils/hct/gen_intrin_main.txt#L86-L376
+ "D3DCOLORtoUBYTE4",
+ "GetRenderTargetSampleCount",
+ "GetRenderTargetSamplePosition",
+ "abort",
+ "abs",
+ "acos",
+ "all",
+ "AllMemoryBarrier",
+ "AllMemoryBarrierWithGroupSync",
+ "any",
+ "asdouble",
+ "asfloat",
+ "asfloat16",
+ "asint16",
+ "asin",
+ "asint",
+ "asuint",
+ "asuint16",
+ "atan",
+ "atan2",
+ "ceil",
+ "clamp",
+ "clip",
+ "cos",
+ "cosh",
+ "countbits",
+ "cross",
+ "ddx",
+ "ddx_coarse",
+ "ddx_fine",
+ "ddy",
+ "ddy_coarse",
+ "ddy_fine",
+ "degrees",
+ "determinant",
+ "DeviceMemoryBarrier",
+ "DeviceMemoryBarrierWithGroupSync",
+ "distance",
+ "dot",
+ "dst",
+ "EvaluateAttributeAtSample",
+ "EvaluateAttributeCentroid",
+ "EvaluateAttributeSnapped",
+ "GetAttributeAtVertex",
+ "exp",
+ "exp2",
+ "f16tof32",
+ "f32tof16",
+ "faceforward",
+ "firstbithigh",
+ "firstbitlow",
+ "floor",
+ "fma",
+ "fmod",
+ "frac",
+ "frexp",
+ "fwidth",
+ "GroupMemoryBarrier",
+ "GroupMemoryBarrierWithGroupSync",
+ "InterlockedAdd",
+ "InterlockedMin",
+ "InterlockedMax",
+ "InterlockedAnd",
+ "InterlockedOr",
+ "InterlockedXor",
+ "InterlockedCompareStore",
+ "InterlockedExchange",
+ "InterlockedCompareExchange",
+ "InterlockedCompareStoreFloatBitwise",
+ "InterlockedCompareExchangeFloatBitwise",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "ldexp",
+ "length",
+ "lerp",
+ "lit",
+ "log",
+ "log10",
+ "log2",
+ "mad",
+ "max",
+ "min",
+ "modf",
+ "msad4",
+ "mul",
+ "normalize",
+ "pow",
+ "printf",
+ "Process2DQuadTessFactorsAvg",
+ "Process2DQuadTessFactorsMax",
+ "Process2DQuadTessFactorsMin",
+ "ProcessIsolineTessFactors",
+ "ProcessQuadTessFactorsAvg",
+ "ProcessQuadTessFactorsMax",
+ "ProcessQuadTessFactorsMin",
+ "ProcessTriTessFactorsAvg",
+ "ProcessTriTessFactorsMax",
+ "ProcessTriTessFactorsMin",
+ "radians",
+ "rcp",
+ "reflect",
+ "refract",
+ "reversebits",
+ "round",
+ "rsqrt",
+ "saturate",
+ "sign",
+ "sin",
+ "sincos",
+ "sinh",
+ "smoothstep",
+ "source_mark",
+ "sqrt",
+ "step",
+ "tan",
+ "tanh",
+ "tex1D",
+ "tex1Dbias",
+ "tex1Dgrad",
+ "tex1Dlod",
+ "tex1Dproj",
+ "tex2D",
+ "tex2Dbias",
+ "tex2Dgrad",
+ "tex2Dlod",
+ "tex2Dproj",
+ "tex3D",
+ "tex3Dbias",
+ "tex3Dgrad",
+ "tex3Dlod",
+ "tex3Dproj",
+ "texCUBE",
+ "texCUBEbias",
+ "texCUBEgrad",
+ "texCUBElod",
+ "texCUBEproj",
+ "transpose",
+ "trunc",
+ "CheckAccessFullyMapped",
+ "AddUint64",
+ "NonUniformResourceIndex",
+ "WaveIsFirstLane",
+ "WaveGetLaneIndex",
+ "WaveGetLaneCount",
+ "WaveActiveAnyTrue",
+ "WaveActiveAllTrue",
+ "WaveActiveAllEqual",
+ "WaveActiveBallot",
+ "WaveReadLaneAt",
+ "WaveReadLaneFirst",
+ "WaveActiveCountBits",
+ "WaveActiveSum",
+ "WaveActiveProduct",
+ "WaveActiveBitAnd",
+ "WaveActiveBitOr",
+ "WaveActiveBitXor",
+ "WaveActiveMin",
+ "WaveActiveMax",
+ "WavePrefixCountBits",
+ "WavePrefixSum",
+ "WavePrefixProduct",
+ "WaveMatch",
+ "WaveMultiPrefixBitAnd",
+ "WaveMultiPrefixBitOr",
+ "WaveMultiPrefixBitXor",
+ "WaveMultiPrefixCountBits",
+ "WaveMultiPrefixProduct",
+ "WaveMultiPrefixSum",
+ "QuadReadLaneAt",
+ "QuadReadAcrossX",
+ "QuadReadAcrossY",
+ "QuadReadAcrossDiagonal",
+ "QuadAny",
+ "QuadAll",
+ "TraceRay",
+ "ReportHit",
+ "CallShader",
+ "IgnoreHit",
+ "AcceptHitAndEndSearch",
+ "DispatchRaysIndex",
+ "DispatchRaysDimensions",
+ "WorldRayOrigin",
+ "WorldRayDirection",
+ "ObjectRayOrigin",
+ "ObjectRayDirection",
+ "RayTMin",
+ "RayTCurrent",
+ "PrimitiveIndex",
+ "InstanceID",
+ "InstanceIndex",
+ "GeometryIndex",
+ "HitKind",
+ "RayFlags",
+ "ObjectToWorld",
+ "WorldToObject",
+ "ObjectToWorld3x4",
+ "WorldToObject3x4",
+ "ObjectToWorld4x3",
+ "WorldToObject4x3",
+ "dot4add_u8packed",
+ "dot4add_i8packed",
+ "dot2add",
+ "unpack_s8s16",
+ "unpack_u8u16",
+ "unpack_s8s32",
+ "unpack_u8u32",
+ "pack_s8",
+ "pack_u8",
+ "pack_clamp_s8",
+ "pack_clamp_u8",
+ "SetMeshOutputCounts",
+ "DispatchMesh",
+ "IsHelperLane",
+ "AllocateRayQuery",
+ "CreateResourceFromHeap",
+ "and",
+ "or",
+ "select",
+ // DXC resource and other types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/HlslTypes.cpp#L441-#L572
+ "InputPatch",
+ "OutputPatch",
+ "PointStream",
+ "LineStream",
+ "TriangleStream",
+ "Texture1D",
+ "RWTexture1D",
+ "Texture2D",
+ "RWTexture2D",
+ "Texture2DMS",
+ "RWTexture2DMS",
+ "Texture3D",
+ "RWTexture3D",
+ "TextureCube",
+ "RWTextureCube",
+ "Texture1DArray",
+ "RWTexture1DArray",
+ "Texture2DArray",
+ "RWTexture2DArray",
+ "Texture2DMSArray",
+ "RWTexture2DMSArray",
+ "TextureCubeArray",
+ "RWTextureCubeArray",
+ "FeedbackTexture2D",
+ "FeedbackTexture2DArray",
+ "RasterizerOrderedTexture1D",
+ "RasterizerOrderedTexture2D",
+ "RasterizerOrderedTexture3D",
+ "RasterizerOrderedTexture1DArray",
+ "RasterizerOrderedTexture2DArray",
+ "RasterizerOrderedBuffer",
+ "RasterizerOrderedByteAddressBuffer",
+ "RasterizerOrderedStructuredBuffer",
+ "ByteAddressBuffer",
+ "RWByteAddressBuffer",
+ "StructuredBuffer",
+ "RWStructuredBuffer",
+ "AppendStructuredBuffer",
+ "ConsumeStructuredBuffer",
+ "Buffer",
+ "RWBuffer",
+ "SamplerState",
+ "SamplerComparisonState",
+ "ConstantBuffer",
+ "TextureBuffer",
+ "RaytracingAccelerationStructure",
+ // DXC templated types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp
+ // look for `BuiltinTypeDeclBuilder`
+ "matrix",
+ "vector",
+ "TextureBuffer",
+ "ConstantBuffer",
+ "RayQuery",
+ // Naga utilities
+ super::writer::MODF_FUNCTION,
+ super::writer::FREXP_FUNCTION,
+];
+
+// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
+// + vector and matrix shorthands
+pub const TYPES: &[&str] = &{
+ const L: usize = 23 * (1 + 4 + 4 * 4);
+ let mut res = [""; L];
+ let mut c = 0;
+
+ /// For each scalar type, it will additionally generate vector and matrix shorthands
+ macro_rules! generate {
+ ([$($roots:literal),*], $x:tt) => {
+ $(
+ generate!(@inner push $roots);
+ generate!(@inner $roots, $x);
+ )*
+ };
+
+ (@inner $root:literal, [$($x:literal),*]) => {
+ generate!(@inner vector $root, $($x)*);
+ generate!(@inner matrix $root, $($x)*);
+ };
+
+ (@inner vector $root:literal, $($x:literal)*) => {
+ $(
+ generate!(@inner push concat!($root, $x));
+ )*
+ };
+
+ (@inner matrix $root:literal, $($x:literal)*) => {
+ // Duplicate the list
+ generate!(@inner matrix $root, $($x)*; $($x)*);
+ };
+
+ // The head/tail recursion: pick the first element of the first list and recursively do it for the tail.
+ (@inner matrix $root:literal, $head:literal $($tail:literal)*; $($x:literal)*) => {
+ $(
+ generate!(@inner push concat!($root, $head, "x", $x));
+ )*
+ generate!(@inner matrix $root, $($tail)*; $($x)*);
+
+ };
+
+ // The end of iteration: we exhausted the list
+ (@inner matrix $root:literal, ; $($x:literal)*) => {};
+
+ (@inner push $v:expr) => {
+ res[c] = $v;
+ c += 1;
+ };
+ }
+
+ generate!(
+ [
+ "bool",
+ "int",
+ "uint",
+ "dword",
+ "half",
+ "float",
+ "double",
+ "min10float",
+ "min16float",
+ "min12int",
+ "min16int",
+ "min16uint",
+ "int16_t",
+ "int32_t",
+ "int64_t",
+ "uint16_t",
+ "uint32_t",
+ "uint64_t",
+ "float16_t",
+ "float32_t",
+ "float64_t",
+ "int8_t4_packed",
+ "uint8_t4_packed"
+ ],
+ ["1", "2", "3", "4"]
+ );
+
+ debug_assert!(c == L);
+
+ res
+};
diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs
new file mode 100644
index 0000000000..37ddbd3d67
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/mod.rs
@@ -0,0 +1,302 @@
+/*!
+Backend for [HLSL][hlsl] (High-Level Shading Language).
+
+# Supported shader model versions:
+- 5.0
+- 5.1
+- 6.0
+
+# Layout of values in `uniform` buffers
+
+WGSL's ["Internal Layout of Values"][ilov] rules specify how each WGSL
+type should be stored in `uniform` and `storage` buffers. The HLSL we
+generate must access values in that form, even when it is not what
+HLSL would use normally.
+
+The rules described here only apply to WGSL `uniform` variables. WGSL
+`storage` buffers are translated as HLSL `ByteAddressBuffers`, for
+which we generate `Load` and `Store` method calls with explicit byte
+offsets. WGSL pipeline inputs must be scalars or vectors; they cannot
+be matrices, which is where the interesting problems arise.
+
+## Row- and column-major ordering for matrices
+
+WGSL specifies that matrices in uniform buffers are stored in
+column-major order. This matches HLSL's default, so one might expect
+things to be straightforward. Unfortunately, WGSL and HLSL disagree on
+what indexing a matrix means: in WGSL, `m[i]` retrieves the `i`'th
+*column* of `m`, whereas in HLSL it retrieves the `i`'th *row*. We
+want to avoid translating `m[i]` into some complicated reassembly of a
+vector from individually fetched components, so this is a problem.
+
+However, with a bit of trickery, it is possible to use HLSL's `m[i]`
+as the translation of WGSL's `m[i]`:
+
+- We declare all matrices in uniform buffers in HLSL with the
+ `row_major` qualifier, and transpose the row and column counts: a
+ WGSL `mat3x4<f32>`, say, becomes an HLSL `row_major float3x4`. (Note
+ that WGSL and HLSL type names put the row and column in reverse
+ order.) Since the HLSL type is the transpose of how WebGPU directs
+ the user to store the data, HLSL will load all matrices transposed.
+
+- Since matrices are transposed, an HLSL indexing expression retrieves
+ the "columns" of the intended WGSL value, as desired.
+
+- For vector-matrix multiplication, since `mul(transpose(m), v)` is
+ equivalent to `mul(v, m)` (note the reversal of the arguments), and
+ `mul(v, transpose(m))` is equivalent to `mul(m, v)`, we can
+ translate WGSL `m * v` and `v * m` to HLSL by simply reversing the
+ arguments to `mul`.
+
+## Padding in two-row matrices
+
+An HLSL `row_major floatKx2` matrix has padding between its rows that
+the WGSL `matKx2<f32>` matrix it represents does not. HLSL stores all
+matrix rows [aligned on 16-byte boundaries][16bb], whereas WGSL says
+that the columns of a `matKx2<f32>` need only be [aligned as required
+for `vec2<f32>`][ilov], which is [eight-byte alignment][8bb].
+
+To compensate for this, any time a `matKx2<f32>` appears in a WGSL
+`uniform` variable, whether directly as the variable's type or as part
+of a struct/array, we actually emit `K` separate `float2` members, and
+assemble/disassemble the matrix from its columns (in WGSL; rows in
+HLSL) upon load and store.
+
+For example, the following WGSL struct type:
+
+```ignore
+struct Baz {
+ m: mat3x2<f32>,
+}
+```
+
+is rendered as the HLSL struct type:
+
+```ignore
+struct Baz {
+ float2 m_0; float2 m_1; float2 m_2;
+};
+```
+
+The `wrapped_struct_matrix` functions in `help.rs` generate HLSL
+helper functions to access such members, converting between the stored
+form and the HLSL matrix types appropriately. For example, for reading
+the member `m` of the `Baz` struct above, we emit:
+
+```ignore
+float3x2 GetMatmOnBaz(Baz obj) {
+ return float3x2(obj.m_0, obj.m_1, obj.m_2);
+}
+```
+
+We also emit an analogous `Set` function, as well as functions for
+accessing individual columns by dynamic index.
+
+[hlsl]: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl
+[ilov]: https://gpuweb.github.io/gpuweb/wgsl/#internal-value-layout
+[16bb]: https://github.com/microsoft/DirectXShaderCompiler/wiki/Buffer-Packing#constant-buffer-packing
+[8bb]: https://gpuweb.github.io/gpuweb/wgsl/#alignment-and-size
+*/
+
+mod conv;
+mod help;
+mod keywords;
+mod storage;
+mod writer;
+
+use std::fmt::Error as FmtError;
+use thiserror::Error;
+
+use crate::{back, proc};
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BindTarget {
+ pub space: u8,
+ pub register: u32,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
+
+/// A HLSL shader model version.
+#[allow(non_snake_case, non_camel_case_types)]
+#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum ShaderModel {
+ V5_0,
+ V5_1,
+ V6_0,
+}
+
+impl ShaderModel {
+ pub const fn to_str(self) -> &'static str {
+ match self {
+ Self::V5_0 => "5_0",
+ Self::V5_1 => "5_1",
+ Self::V6_0 => "6_0",
+ }
+ }
+}
+
+impl crate::ShaderStage {
+ pub const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::Vertex => "vs",
+ Self::Fragment => "ps",
+ Self::Compute => "cs",
+ }
+ }
+}
+
+impl crate::ImageDimension {
+ const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::D1 => "1D",
+ Self::D2 => "2D",
+ Self::D3 => "3D",
+ Self::Cube => "Cube",
+ }
+ }
+}
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("mapping of {0:?} is missing")]
+ MissingBinding(crate::ResourceBinding),
+}
+
+/// Configuration used in the [`Writer`].
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// The hlsl shader model to be used
+ pub shader_model: ShaderModel,
+ /// Map of resources association to binding locations.
+ pub binding_map: BindingMap,
+ /// Don't panic on missing bindings, instead generate any HLSL.
+ pub fake_missing_bindings: bool,
+ /// Add special constants to `SV_VertexIndex` and `SV_InstanceIndex`,
+ /// to make them work like in Vulkan/Metal, with help of the host.
+ pub special_constants_binding: Option<BindTarget>,
+ /// Bind target of the push constant buffer
+ pub push_constants_target: Option<BindTarget>,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ shader_model: ShaderModel::V5_1,
+ binding_map: BindingMap::default(),
+ fake_missing_bindings: true,
+ special_constants_binding: None,
+ push_constants_target: None,
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+impl Options {
+ fn resolve_resource_binding(
+ &self,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<BindTarget, EntryPointError> {
+ match self.binding_map.get(res_binding) {
+ Some(target) => Ok(target.clone()),
+ None if self.fake_missing_bindings => Ok(BindTarget {
+ space: res_binding.group as u8,
+ register: res_binding.binding,
+ binding_array_size: None,
+ }),
+ None => Err(EntryPointError::MissingBinding(res_binding.clone())),
+ }
+ }
+}
+
+/// Reflection info for entry point names.
+#[derive(Default)]
+pub struct ReflectionInfo {
+ /// Mapping of the entry point names.
+ ///
+ /// Each item in the array corresponds to an entry point index. The real entry point name may be different if one of the
+ /// reserved words are used.
+ ///
+ /// Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error(transparent)]
+ IoError(#[from] FmtError),
+ #[error("A scalar with an unsupported width was requested: {0:?}")]
+ UnsupportedScalar(crate::Scalar),
+ #[error("{0}")]
+ Unimplemented(String), // TODO: Error used only during development
+ #[error("{0}")]
+ Custom(String),
+}
+
+#[derive(Default)]
+struct Wrapped {
+ array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
+ image_queries: crate::FastHashSet<help::WrappedImageQuery>,
+ constructors: crate::FastHashSet<help::WrappedConstructor>,
+ struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>,
+ mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>,
+}
+
+impl Wrapped {
+ fn clear(&mut self) {
+ self.array_lengths.clear();
+ self.image_queries.clear();
+ self.constructors.clear();
+ self.struct_matrix_access.clear();
+ self.mat_cx2s.clear();
+ }
+}
+
+pub struct Writer<'a, W> {
+ out: W,
+ names: crate::FastHashMap<proc::NameKey, String>,
+ namer: proc::Namer,
+ /// HLSL backend options
+ options: &'a Options,
+ /// Information about entry point arguments and result types.
+ entry_point_io: Vec<writer::EntryPointInterface>,
+ /// Set of expressions that have associated temporary variables
+ named_expressions: crate::NamedExpressions,
+ wrapped: Wrapped,
+
+ /// A reference to some part of a global variable, lowered to a series of
+ /// byte offset calculations.
+ ///
+ /// See the [`storage`] module for background on why we need this.
+ ///
+ /// Each [`SubAccess`] in the vector is a lowering of some [`Access`] or
+ /// [`AccessIndex`] expression to the level of byte strides and offsets. See
+ /// [`SubAccess`] for details.
+ ///
+ /// This field is a member of [`Writer`] solely to allow re-use of
+ /// the `Vec`'s dynamic allocation. The value is no longer needed
+ /// once HLSL for the access has been generated.
+ ///
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`SubAccess`]: storage::SubAccess
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ temp_access_chain: Vec<storage::SubAccess>,
+ need_bake_expressions: back::NeedBakeExpressions,
+}
diff --git a/third_party/rust/naga/src/back/hlsl/storage.rs b/third_party/rust/naga/src/back/hlsl/storage.rs
new file mode 100644
index 0000000000..1b8a6ec12d
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/storage.rs
@@ -0,0 +1,494 @@
+/*!
+Generating accesses to [`ByteAddressBuffer`] contents.
+
+Naga IR globals in the [`Storage`] address space are rendered as
+[`ByteAddressBuffer`]s or [`RWByteAddressBuffer`]s in HLSL. These
+buffers don't have HLSL types (structs, arrays, etc.); instead, they
+are just raw blocks of bytes, with methods to load and store values of
+specific types at particular byte offsets. This means that Naga must
+translate chains of [`Access`] and [`AccessIndex`] expressions into
+HLSL expressions that compute byte offsets into the buffer.
+
+To generate code for a [`Storage`] access:
+
+- Call [`Writer::fill_access_chain`] on the expression referring to
+ the value. This populates [`Writer::temp_access_chain`] with the
+ appropriate byte offset calculations, as a vector of [`SubAccess`]
+ values.
+
+- Call [`Writer::write_storage_address`] to emit an HLSL expression
+ for a given slice of [`SubAccess`] values.
+
+Naga IR expressions can operate on composite values of any type, but
+[`ByteAddressBuffer`] and [`RWByteAddressBuffer`] have only a fixed
+set of `Load` and `Store` methods, to access one through four
+consecutive 32-bit values. To synthesize a Naga access, you can
+initialize [`temp_access_chain`] to refer to the composite, and then
+temporarily push and pop additional steps on
+[`Writer::temp_access_chain`] to generate accesses to the individual
+elements/members.
+
+The [`temp_access_chain`] field is a member of [`Writer`] solely to
+allow re-use of the `Vec`'s dynamic allocation. Its value is no longer
+needed once HLSL for the access has been generated.
+
+[`Storage`]: crate::AddressSpace::Storage
+[`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer
+[`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer
+[`Access`]: crate::Expression::Access
+[`AccessIndex`]: crate::Expression::AccessIndex
+[`Writer::fill_access_chain`]: super::Writer::fill_access_chain
+[`Writer::write_storage_address`]: super::Writer::write_storage_address
+[`Writer::temp_access_chain`]: super::Writer::temp_access_chain
+[`temp_access_chain`]: super::Writer::temp_access_chain
+[`Writer`]: super::Writer
+*/
+
+use super::{super::FunctionCtx, BackendResult, Error};
+use crate::{
+ proc::{Alignment, NameKey, TypeResolution},
+ Handle,
+};
+
+use std::{fmt, mem};
+
+const STORE_TEMP_NAME: &str = "_value";
+
+/// One step in accessing a [`Storage`] global's component or element.
+///
+/// [`Writer::temp_access_chain`] holds a series of these structures,
+/// describing how to compute the byte offset of a particular element
+/// or member of some global variable in the [`Storage`] address
+/// space.
+///
+/// [`Writer::temp_access_chain`]: super::Writer::temp_access_chain
+/// [`Storage`]: crate::AddressSpace::Storage
+#[derive(Debug)]
+pub(super) enum SubAccess {
+ /// Add the given byte offset. This is used for struct members, or
+ /// known components of a vector or matrix. In all those cases,
+ /// the byte offset is a compile-time constant.
+ Offset(u32),
+
+ /// Scale `value` by `stride`, and add that to the current byte
+ /// offset. This is used to compute the offset of an array element
+ /// whose index is computed at runtime.
+ Index {
+ value: Handle<crate::Expression>,
+ stride: u32,
+ },
+}
+
+pub(super) enum StoreValue {
+ Expression(Handle<crate::Expression>),
+ TempIndex {
+ depth: usize,
+ index: u32,
+ ty: TypeResolution,
+ },
+ TempAccess {
+ depth: usize,
+ base: Handle<crate::Type>,
+ member_index: u32,
+ },
+}
+
+impl<W: fmt::Write> super::Writer<'_, W> {
+ pub(super) fn write_storage_address(
+ &mut self,
+ module: &crate::Module,
+ chain: &[SubAccess],
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ if chain.is_empty() {
+ write!(self.out, "0")?;
+ }
+ for (i, access) in chain.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, "+")?;
+ }
+ match *access {
+ SubAccess::Offset(offset) => {
+ write!(self.out, "{offset}")?;
+ }
+ SubAccess::Index { value, stride } => {
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, "*{stride}")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ sequence: I,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ for (i, (ty_resolution, offset)) in sequence.enumerate() {
+ // add the index temporarily
+ self.temp_access_chain.push(SubAccess::Offset(offset));
+ if i != 0 {
+ write!(self.out, ", ")?;
+ };
+ self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?;
+ self.temp_access_chain.pop();
+ }
+ Ok(())
+ }
+
+ /// Emit code to access a [`Storage`] global's component.
+ ///
+ /// Emit HLSL to access the component of `var_handle`, a global
+ /// variable in the [`Storage`] address space, whose type is
+ /// `result_ty` and whose location within the global is given by
+ /// [`self.temp_access_chain`]. See the [`storage`] module's
+ /// documentation for background.
+ ///
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`self.temp_access_chain`]: super::Writer::temp_access_chain
+ pub(super) fn write_storage_load(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ result_ty: TypeResolution,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *result_ty.inner_with(&module.types) {
+ crate::TypeInner::Scalar(scalar) => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = scalar.kind.to_hlsl_cast();
+ write!(self.out, "{cast}({var_name}.Load(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector { size, scalar } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = scalar.kind.to_hlsl_cast();
+ write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ scalar.to_hlsl_str()?,
+ columns as u8,
+ rows as u8,
+ )?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * scalar.width as u32;
+ let iter = (0..columns as u32).map(|i| {
+ let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
+ (TypeResolution::Value(ty_inner), i * row_stride)
+ });
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ stride,
+ } => {
+ let constructor = super::help::WrappedConstructor {
+ ty: result_ty.handle().unwrap(),
+ };
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ write!(self.out, "(")?;
+ let iter = (0..size.get()).map(|i| (TypeResolution::Handle(base), stride * i));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ let constructor = super::help::WrappedConstructor {
+ ty: result_ty.handle().unwrap(),
+ };
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ write!(self.out, "(")?;
+ let iter = members
+ .iter()
+ .map(|m| (TypeResolution::Handle(m.ty), m.offset));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ fn write_store_value(
+ &mut self,
+ module: &crate::Module,
+ value: &StoreValue,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *value {
+ StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?,
+ StoreValue::TempIndex {
+ depth,
+ index,
+ ty: _,
+ } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?,
+ StoreValue::TempAccess {
+ depth,
+ base,
+ member_index,
+ } => {
+ let name = &self.names[&NameKey::StructMember(base, member_index)];
+ write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")?
+ }
+ }
+ Ok(())
+ }
+
+ /// Helper function to write down the Store operation on a `ByteAddressBuffer`.
+ pub(super) fn write_storage_store(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ value: StoreValue,
+ func_ctx: &FunctionCtx,
+ level: crate::back::Level,
+ ) -> BackendResult {
+ let temp_resolution;
+ let ty_resolution = match value {
+ StoreValue::Expression(expr) => &func_ctx.info[expr].ty,
+ StoreValue::TempIndex {
+ depth: _,
+ index: _,
+ ref ty,
+ } => ty,
+ StoreValue::TempAccess {
+ depth: _,
+ base,
+ member_index,
+ } => {
+ let ty_handle = match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ members[member_index as usize].ty
+ }
+ _ => unreachable!(),
+ };
+ temp_resolution = TypeResolution::Handle(ty_handle);
+ &temp_resolution
+ }
+ };
+ match *ty_resolution.inner_with(&module.types) {
+ crate::TypeInner::Scalar(_) => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{level}{var_name}.Store(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ let depth = level.0 + 1;
+ write!(
+ self.out,
+ "{}{}{}x{} {}{} = ",
+ level.next(),
+ scalar.to_hlsl_str()?,
+ columns as u8,
+ rows as u8,
+ STORE_TEMP_NAME,
+ depth,
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * scalar.width as u32;
+
+ // then iterate the stores
+ for i in 0..columns as u32 {
+ self.temp_access_chain
+ .push(SubAccess::Offset(i * row_stride));
+ let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Value(ty_inner),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ stride,
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ write!(self.out, "{}", level.next())?;
+ self.write_value_type(module, &module.types[base].inner)?;
+ let depth = level.next().0;
+ write!(self.out, " {STORE_TEMP_NAME}{depth}")?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
+ write!(self.out, " = ")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ for i in 0..size.get() {
+ self.temp_access_chain.push(SubAccess::Offset(i * stride));
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Handle(base),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ let depth = level.next().0;
+ let struct_ty = ty_resolution.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(struct_ty)];
+ write!(
+ self.out,
+ "{}{} {}{} = ",
+ level.next(),
+ struct_name,
+ STORE_TEMP_NAME,
+ depth
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ for (i, member) in members.iter().enumerate() {
+ self.temp_access_chain
+ .push(SubAccess::Offset(member.offset));
+ let sv = StoreValue::TempAccess {
+ depth,
+ base: struct_ty,
+ member_index: i as u32,
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ /// Set [`temp_access_chain`] to compute the byte offset of `cur_expr`.
+ ///
+ /// The `cur_expr` expression must be a reference to a global
+ /// variable in the [`Storage`] address space, or a chain of
+ /// [`Access`] and [`AccessIndex`] expressions referring to some
+ /// component of such a global.
+ ///
+ /// [`temp_access_chain`]: super::Writer::temp_access_chain
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ pub(super) fn fill_access_chain(
+ &mut self,
+ module: &crate::Module,
+ mut cur_expr: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> Result<Handle<crate::GlobalVariable>, Error> {
+ enum AccessIndex {
+ Expression(Handle<crate::Expression>),
+ Constant(u32),
+ }
+ enum Parent<'a> {
+ Array { stride: u32 },
+ Struct(&'a [crate::StructMember]),
+ }
+ self.temp_access_chain.clear();
+
+ loop {
+ let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
+ crate::Expression::GlobalVariable(handle) => return Ok(handle),
+ crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
+ crate::Expression::AccessIndex { base, index } => {
+ (base, AccessIndex::Constant(index))
+ }
+ ref other => {
+ return Err(Error::Unimplemented(format!("Pointer access of {other:?}")))
+ }
+ };
+
+ let parent = match *func_ctx.resolve_type(next_expr, &module.types) {
+ crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
+ crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
+ crate::TypeInner::Vector { scalar, .. } => Parent::Array {
+ stride: scalar.width as u32,
+ },
+ crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array {
+ // The stride between matrices is the count of rows as this is how
+ // long each column is.
+ stride: Alignment::from(rows) * scalar.width as u32,
+ },
+ _ => unreachable!(),
+ },
+ crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array {
+ stride: scalar.width as u32,
+ },
+ _ => unreachable!(),
+ };
+
+ let sub = match (parent, access_index) {
+ (Parent::Array { stride }, AccessIndex::Expression(value)) => {
+ SubAccess::Index { value, stride }
+ }
+ (Parent::Array { stride }, AccessIndex::Constant(index)) => {
+ SubAccess::Offset(stride * index)
+ }
+ (Parent::Struct(members), AccessIndex::Constant(index)) => {
+ SubAccess::Offset(members[index as usize].offset)
+ }
+ (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(),
+ };
+
+ self.temp_access_chain.push(sub);
+ cur_expr = next_expr;
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs
new file mode 100644
index 0000000000..43f7212837
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/writer.rs
@@ -0,0 +1,3366 @@
+use super::{
+ help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess},
+ storage::StoreValue,
+ BackendResult, Error, Options,
+};
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, Module, ScalarKind, ShaderStage, TypeInner,
+};
+use std::{fmt, mem};
+
+const LOCATION_SEMANTIC: &str = "LOC";
+const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
+const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
+const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
+const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
+const SPECIAL_OTHER: &str = "other";
+
+pub(crate) const MODF_FUNCTION: &str = "naga_modf";
+pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
+
+struct EpStructMember {
+ name: String,
+ ty: Handle<crate::Type>,
+ // technically, this should always be `Some`
+ binding: Option<crate::Binding>,
+ index: u32,
+}
+
+/// Structure contains information required for generating
+/// wrapped structure of all entry points arguments
+struct EntryPointBinding {
+ /// Name of the fake EP argument that contains the struct
+ /// with all the flattened input data.
+ arg_name: String,
+ /// Generated structure name
+ ty_name: String,
+ /// Members of generated structure
+ members: Vec<EpStructMember>,
+}
+
+pub(super) struct EntryPointInterface {
+ /// If `Some`, the input of an entry point is gathered in a special
+ /// struct with members sorted by binding.
+ /// The `EntryPointBinding::members` array is sorted by index,
+ /// so that we can walk it in `write_ep_arguments_initialization`.
+ input: Option<EntryPointBinding>,
+ /// If `Some`, the output of an entry point is flattened.
+ /// The `EntryPointBinding::members` array is sorted by binding,
+ /// So that we can walk it in `Statement::Return` handler.
+ output: Option<EntryPointBinding>,
+}
+
+#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
+enum InterfaceKey {
+ Location(u32),
+ BuiltIn(crate::BuiltIn),
+ Other,
+}
+
+impl InterfaceKey {
+ const fn new(binding: Option<&crate::Binding>) -> Self {
+ match binding {
+ Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
+ Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
+ None => Self::Other,
+ }
+ }
+}
+
+#[derive(Copy, Clone, PartialEq)]
+enum Io {
+ Input,
+ Output,
+}
+
+impl<'a, W: fmt::Write> super::Writer<'a, W> {
+ pub fn new(out: W, options: &'a Options) -> Self {
+ Self {
+ out,
+ names: crate::FastHashMap::default(),
+ namer: proc::Namer::default(),
+ options,
+ entry_point_io: Vec::new(),
+ named_expressions: crate::NamedExpressions::default(),
+ wrapped: super::Wrapped::default(),
+ temp_access_chain: Vec::new(),
+ need_bake_expressions: Default::default(),
+ }
+ }
+
+ fn reset(&mut self, module: &Module) {
+ self.names.clear();
+ self.namer.reset(
+ module,
+ super::keywords::RESERVED,
+ super::keywords::TYPES,
+ super::keywords::RESERVED_CASE_INSENSITIVE,
+ &[],
+ &mut self.names,
+ );
+ self.entry_point_io.clear();
+ self.named_expressions.clear();
+ self.wrapped.clear();
+ self.need_bake_expressions.clear();
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// Clears `need_bake_expressions` set before adding to it
+ fn update_expressions_to_bake(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ ) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+ for (fun_handle, expr) in func.expressions.iter() {
+ let expr_info = &info[fun_handle];
+ let min_ref_count = func.expressions[fun_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(fun_handle);
+ }
+
+ if let Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } = *expr
+ {
+ match fun {
+ crate::MathFunction::Asinh
+ | crate::MathFunction::Acosh
+ | crate::MathFunction::Atanh
+ | crate::MathFunction::Unpack2x16float
+ | crate::MathFunction::Unpack2x16snorm
+ | crate::MathFunction::Unpack2x16unorm
+ | crate::MathFunction::Unpack4x8snorm
+ | crate::MathFunction::Unpack4x8unorm
+ | crate::MathFunction::Pack2x16float
+ | crate::MathFunction::Pack2x16snorm
+ | crate::MathFunction::Pack2x16unorm
+ | crate::MathFunction::Pack4x8snorm
+ | crate::MathFunction::Pack4x8unorm => {
+ self.need_bake_expressions.insert(arg);
+ }
+ crate::MathFunction::ExtractBits => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ self.need_bake_expressions.insert(arg2.unwrap());
+ }
+ crate::MathFunction::InsertBits => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ self.need_bake_expressions.insert(arg2.unwrap());
+ self.need_bake_expressions.insert(arg3.unwrap());
+ }
+ crate::MathFunction::CountLeadingZeros => {
+ let inner = info[fun_handle].ty.inner_with(&module.types);
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ self.need_bake_expressions.insert(arg);
+ }
+ }
+ _ => {}
+ }
+ }
+
+ if let Expression::Derivative { axis, ctrl, expr } = *expr {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
+ self.need_bake_expressions.insert(expr);
+ }
+ }
+ }
+ }
+
+ pub fn write(
+ &mut self,
+ module: &Module,
+ module_info: &valid::ModuleInfo,
+ ) -> Result<super::ReflectionInfo, Error> {
+ self.reset(module);
+
+ // Write special constants, if needed
+ if let Some(ref bt) = self.options.special_constants_binding {
+ writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
+ writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
+ writeln!(self.out, "}};")?;
+ write!(
+ self.out,
+ "ConstantBuffer<{}> {}: register(b{}",
+ SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
+ )?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ writeln!(self.out, ");")?;
+
+ // Extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Save all entry point output types
+ let ep_results = module
+ .entry_points
+ .iter()
+ .map(|ep| (ep.stage, ep.function.result.clone()))
+ .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
+
+ self.write_all_mat_cx2_typedefs_and_functions(module)?;
+
+ // Write all structs
+ for (handle, ty) in module.types.iter() {
+ if let TypeInner::Struct { ref members, span } = ty.inner {
+ if module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&module.types)
+ {
+ // unsized arrays can only be in storage buffers,
+ // for which we use `ByteAddressBuffer` anyway.
+ continue;
+ }
+
+ let ep_result = ep_results.iter().find(|e| {
+ if let Some(ref result) = e.1 {
+ result.ty == handle
+ } else {
+ false
+ }
+ });
+
+ self.write_struct(
+ module,
+ handle,
+ members,
+ span,
+ ep_result.map(|r| (r.0, Io::Output)),
+ )?;
+ writeln!(self.out)?;
+ }
+ }
+
+ self.write_special_functions(module)?;
+
+ self.write_wrapped_compose_functions(module, &module.const_expressions)?;
+
+ // Write all named constants
+ let mut constants = module
+ .constants
+ .iter()
+ .filter(|&(_, c)| c.name.is_some())
+ .peekable();
+ while let Some((handle, _)) = constants.next() {
+ self.write_global_constant(module, handle)?;
+ // Add extra newline for readability on last iteration
+ if constants.peek().is_none() {
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write all globals
+ for (ty, _) in module.global_variables.iter() {
+ self.write_global(module, ty)?;
+ }
+
+ if !module.global_variables.is_empty() {
+ // Add extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Write all entry points wrapped structs
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ let ep_io = self.write_ep_interface(module, &ep.function, ep.stage, &ep_name)?;
+ self.entry_point_io.push(ep_io);
+ }
+
+ // Write all regular functions
+ for (handle, function) in module.functions.iter() {
+ let info = &module_info[handle];
+
+ // Check if all of the globals are accessible
+ if !self.options.fake_missing_bindings {
+ if let Some((var_handle, _)) =
+ module
+ .global_variables
+ .iter()
+ .find(|&(var_handle, var)| match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ self.options.resolve_resource_binding(binding).is_err()
+ }
+ _ => false,
+ })
+ {
+ log::info!(
+ "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
+ handle,
+ function.name,
+ var_handle
+ );
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::Function(handle),
+ info,
+ expressions: &function.expressions,
+ named_expressions: &function.named_expressions,
+ };
+ let name = self.names[&NameKey::Function(handle)].clone();
+
+ self.write_wrapped_functions(module, &ctx)?;
+
+ self.write_function(module, name.as_str(), function, &ctx, info)?;
+
+ writeln!(self.out)?;
+ }
+
+ let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
+
+ // Write all entry points
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let info = module_info.get_entry_point(index);
+
+ if !self.options.fake_missing_bindings {
+ let mut ep_error = None;
+ for (var_handle, var) in module.global_variables.iter() {
+ match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ ep_error = Some(err);
+ break;
+ }
+ }
+ _ => {}
+ }
+ }
+ if let Some(err) = ep_error {
+ entry_point_names.push(Err(err));
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::EntryPoint(index as u16),
+ info,
+ expressions: &ep.function.expressions,
+ named_expressions: &ep.function.named_expressions,
+ };
+
+ self.write_wrapped_functions(module, &ctx)?;
+
+ if ep.stage == ShaderStage::Compute {
+ // HLSL is calling workgroup size "num threads"
+ let num_threads = ep.workgroup_size;
+ writeln!(
+ self.out,
+ "[numthreads({}, {}, {})]",
+ num_threads[0], num_threads[1], num_threads[2]
+ )?;
+ }
+
+ let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ self.write_function(module, &name, &ep.function, &ctx, info)?;
+
+ if index < module.entry_points.len() - 1 {
+ writeln!(self.out)?;
+ }
+
+ entry_point_names.push(Ok(name));
+ }
+
+ Ok(super::ReflectionInfo { entry_point_names })
+ }
+
+ fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
+ write!(self.out, "precise ")?;
+ }
+ crate::Binding::Location {
+ interpolation,
+ sampling,
+ ..
+ } => {
+ if let Some(interpolation) = interpolation {
+ if let Some(string) = interpolation.to_hlsl_str() {
+ write!(self.out, "{string} ")?
+ }
+ }
+
+ if let Some(sampling) = sampling {
+ if let Some(string) = sampling.to_hlsl_str() {
+ write!(self.out, "{string} ")?
+ }
+ }
+ }
+ crate::Binding::BuiltIn(_) => {}
+ }
+
+ Ok(())
+ }
+
+ //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
+ // if they are struct, so that the `stage` argument here could be omitted.
+ fn write_semantic(
+ &mut self,
+ binding: &crate::Binding,
+ stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(builtin) => {
+ let builtin_str = builtin.to_hlsl_str()?;
+ write!(self.out, " : {builtin_str}")?;
+ }
+ crate::Binding::Location {
+ second_blend_source: true,
+ ..
+ } => {
+ write!(self.out, " : SV_Target1")?;
+ }
+ crate::Binding::Location {
+ location,
+ second_blend_source: false,
+ ..
+ } => {
+ if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
+ write!(self.out, " : SV_Target{location}")?;
+ } else {
+ write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_interface_struct(
+ &mut self,
+ module: &Module,
+ shader_stage: (ShaderStage, Io),
+ struct_name: String,
+ mut members: Vec<EpStructMember>,
+ ) -> Result<EntryPointBinding, Error> {
+ // Sort the members so that first come the user-defined varyings
+ // in ascending locations, and then built-ins. This allows VS and FS
+ // interfaces to match with regards to order.
+ members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
+
+ write!(self.out, "struct {struct_name}")?;
+ writeln!(self.out, " {{")?;
+ for m in members.iter() {
+ write!(self.out, "{}", back::INDENT)?;
+ if let Some(ref binding) = m.binding {
+ self.write_modifier(binding)?;
+ }
+ self.write_type(module, m.ty)?;
+ write!(self.out, " {}", &m.name)?;
+ if let Some(ref binding) = m.binding {
+ self.write_semantic(binding, Some(shader_stage))?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+
+ match shader_stage.1 {
+ Io::Input => {
+ // bring back the original order
+ members.sort_by_key(|m| m.index);
+ }
+ Io::Output => {
+ // keep it sorted by binding
+ }
+ }
+
+ Ok(EntryPointBinding {
+ arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
+ ty_name: struct_name,
+ members,
+ })
+ }
+
+ /// Flatten all entry point arguments into a single struct.
+ /// This is needed since we need to re-order them: first placing user locations,
+ /// then built-ins.
+ fn write_ep_input_struct(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{stage:?}Input_{entry_point_name}");
+
+ let mut fake_members = Vec::new();
+ for arg in func.arguments.iter() {
+ match module.types[arg.ty].inner {
+ TypeInner::Struct { ref members, .. } => {
+ for member in members.iter() {
+ let name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+ }
+ _ => {
+ let member_name = self.namer.call_or(&arg.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: arg.ty,
+ binding: arg.binding.clone(),
+ index,
+ });
+ }
+ }
+ }
+
+ self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
+ }
+
+ /// Flatten all entry point results into a single struct.
+ /// This is needed since we need to re-order them: first placing user locations,
+ /// then built-ins.
+ fn write_ep_output_struct(
+ &mut self,
+ module: &Module,
+ result: &crate::FunctionResult,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{stage:?}Output_{entry_point_name}");
+
+ let mut fake_members = Vec::new();
+ let empty = [];
+ let members = match module.types[result.ty].inner {
+ TypeInner::Struct { ref members, .. } => members,
+ ref other => {
+ log::error!("Unexpected {:?} output type without a binding", other);
+ &empty[..]
+ }
+ };
+
+ for member in members.iter() {
+ let member_name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+
+ self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
+ }
+
+ /// Writes special interface structures for an entry point. The special structures have
+ /// all the fields flattened into them and sorted by binding. They are only needed for
+ /// VS outputs and FS inputs, so that these interfaces match.
+ fn write_ep_interface(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ ep_name: &str,
+ ) -> Result<EntryPointInterface, Error> {
+ Ok(EntryPointInterface {
+ input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
+ Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
+ } else {
+ None
+ },
+ output: match func.result {
+ Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
+ Some(self.write_ep_output_struct(module, fr, stage, ep_name)?)
+ }
+ _ => None,
+ },
+ })
+ }
+
+ /// Write an entry point preface that initializes the arguments as specified in IR.
+ fn write_ep_arguments_initialization(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ ep_index: u16,
+ ) -> BackendResult {
+ let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
+ Some(ep_input) => ep_input,
+ None => return Ok(()),
+ };
+ let mut fake_iter = ep_input.members.iter();
+ for (arg_index, arg) in func.arguments.iter().enumerate() {
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(module, arg.ty)?;
+ let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
+ write!(self.out, " {arg_name}")?;
+ match module.types[arg.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ self.write_array_size(module, base, size)?;
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ TypeInner::Struct { ref members, .. } => {
+ write!(self.out, " = {{ ")?;
+ for index in 0..members.len() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let fake_member = fake_iter.next().unwrap();
+ write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ }
+ }
+ assert!(fake_iter.next().is_none());
+ Ok(())
+ }
+
+ /// Helper method used to write global variables
+ /// # Notes
+ /// Always adds a newline
+ fn write_global(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::GlobalVariable>,
+ ) -> BackendResult {
+ let global = &module.global_variables[handle];
+ let inner = &module.types[global.ty].inner;
+
+ if let Some(ref binding) = global.binding {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ log::info!(
+ "Skipping global {:?} (name {:?}) for being inaccessible: {}",
+ handle,
+ global.name,
+ err,
+ );
+ return Ok(());
+ }
+ }
+
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
+ let register_ty = match global.space {
+ crate::AddressSpace::Function => unreachable!("Function address space"),
+ crate::AddressSpace::Private => {
+ write!(self.out, "static ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::WorkGroup => {
+ write!(self.out, "groupshared ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::Uniform => {
+ // constant buffer declarations are expected to be inlined, e.g.
+ // `cbuffer foo: register(b0) { field1: type1; }`
+ write!(self.out, "cbuffer")?;
+ "b"
+ }
+ crate::AddressSpace::Storage { access } => {
+ let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
+ ("RW", "u")
+ } else {
+ ("", "t")
+ };
+ write!(self.out, "{prefix}ByteAddressBuffer")?;
+ register
+ }
+ crate::AddressSpace::Handle => {
+ let handle_ty = match *inner {
+ TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
+ _ => inner,
+ };
+
+ let register = match *handle_ty {
+ TypeInner::Sampler { .. } => "s",
+ // all storage textures are UAV, unconditionally
+ TypeInner::Image {
+ class: crate::ImageClass::Storage { .. },
+ ..
+ } => "u",
+ _ => "t",
+ };
+ self.write_type(module, global.ty)?;
+ register
+ }
+ crate::AddressSpace::PushConstant => {
+ // The type of the push constants will be wrapped in `ConstantBuffer`
+ write!(self.out, "ConstantBuffer<")?;
+ "b"
+ }
+ };
+
+ // If the global is a push constant write the type now because it will be a
+ // generic argument to `ConstantBuffer`
+ if global.space == crate::AddressSpace::PushConstant {
+ self.write_global_type(module, global.ty)?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ // Close the angled brackets for the generic argument
+ write!(self.out, ">")?;
+ }
+
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, " {name}")?;
+
+ // Push constants need to be assigned a binding explicitly by the consumer
+ // since naga has no way to know the binding from the shader alone
+ if global.space == crate::AddressSpace::PushConstant {
+ let target = self
+ .options
+ .push_constants_target
+ .as_ref()
+ .expect("No bind target was defined for the push constants block");
+ write!(self.out, ": register(b{}", target.register)?;
+ if target.space != 0 {
+ write!(self.out, ", space{}", target.space)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ if let Some(ref binding) = global.binding {
+ // this was already resolved earlier when we started evaluating an entry point.
+ let bt = self.options.resolve_resource_binding(binding).unwrap();
+
+ // need to write the binding array size if the type was emitted with `write_type`
+ if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
+ if let Some(overridden_size) = bt.binding_array_size {
+ write!(self.out, "[{overridden_size}]")?;
+ } else {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+
+ write!(self.out, " : register({}{}", register_ty, bt.register)?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ write!(self.out, ")")?;
+ } else {
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ if global.space == crate::AddressSpace::Private {
+ write!(self.out, " = ")?;
+ if let Some(init) = global.init {
+ self.write_const_expression(module, init)?;
+ } else {
+ self.write_default_init(module, global.ty)?;
+ }
+ }
+ }
+
+ if global.space == crate::AddressSpace::Uniform {
+ write!(self.out, " {{ ")?;
+
+ self.write_global_type(module, global.ty)?;
+
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ writeln!(self.out, "; }}")?;
+ } else {
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_global_constant(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ write!(self.out, "static const ")?;
+ let constant = &module.constants[handle];
+ self.write_type(module, constant.ty)?;
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, " {}", name)?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_const_expression(module, constant.init)?;
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ pub(super) fn write_array_size(
+ &mut self,
+ module: &Module,
+ base: Handle<crate::Type>,
+ size: crate::ArraySize,
+ ) -> BackendResult {
+ write!(self.out, "[")?;
+
+ match size {
+ crate::ArraySize::Constant(size) => {
+ write!(self.out, "{size}")?;
+ }
+ crate::ArraySize::Dynamic => unreachable!(),
+ }
+
+ write!(self.out, "]")?;
+
+ if let TypeInner::Array {
+ base: next_base,
+ size: next_size,
+ ..
+ } = module.types[base].inner
+ {
+ self.write_array_size(module, next_base, next_size)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ span: u32,
+ shader_stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ // Write struct name
+ let struct_name = &self.names[&NameKey::Type(handle)];
+ writeln!(self.out, "struct {struct_name} {{")?;
+
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ if member.binding.is_none() && member.offset > last_offset {
+ // using int as padding should work as long as the backend
+ // doesn't support a type that's less than 4 bytes in size
+ // (Error::UnsupportedScalar catches this)
+ let padding = (member.offset - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
+ }
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx());
+
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+
+ match module.types[member.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ // HLSL arrays are written as `type name[size]`
+
+ self.write_global_type(module, member.ty)?;
+
+ // Write `name`
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ // Write [size]
+ self.write_array_size(module, base, size)?;
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ TypeInner::Matrix {
+ rows,
+ columns,
+ scalar,
+ } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
+ let vec_ty = crate::TypeInner::Vector { size: rows, scalar };
+ let field_name_key = NameKey::StructMember(handle, index as u32);
+
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, "; ")?;
+ }
+ self.write_value_type(module, &vec_ty)?;
+ write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
+ }
+ }
+ _ => {
+ // Write modifier before type
+ if let Some(ref binding) = member.binding {
+ self.write_modifier(binding)?;
+ }
+
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
+ write!(self.out, "row_major ")?;
+ }
+
+ // Write the member type and name
+ self.write_type(module, member.ty)?;
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ }
+ }
+
+ if let Some(ref binding) = member.binding {
+ self.write_semantic(binding, shader_stage)?;
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
+ if members.last().unwrap().binding.is_none() && span > last_offset {
+ let padding = (span - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
+ }
+ }
+
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ /// Helper method used to write global/structs non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_global_type(
+ &mut self,
+ module: &Module,
+ ty: Handle<crate::Type>,
+ ) -> BackendResult {
+ let matrix_data = get_inner_matrix_data(module, ty);
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = matrix_data
+ {
+ write!(self.out, "__mat{}x2", columns as u8)?;
+ } else {
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if matrix_data.is_some() {
+ write!(self.out, "row_major ")?;
+ }
+
+ self.write_type(module, ty)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &module.types[ty].inner;
+ match *inner {
+ TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
+ // hlsl array has the size separated from the base type
+ TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
+ self.write_type(module, base)?
+ }
+ ref other => self.write_value_type(module, other)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
+ write!(self.out, "{}", scalar.to_hlsl_str()?)?;
+ }
+ TypeInner::Vector { size, scalar } => {
+ write!(
+ self.out,
+ "{}{}",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ // The IR supports only float matrix
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
+
+ // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
+ write!(
+ self.out,
+ "{}{}x{}",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ )?;
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ self.write_image_type(dim, arrayed, class)?;
+ }
+ TypeInner::Sampler { comparison } => {
+ let sampler = if comparison {
+ "SamplerComparisonState"
+ } else {
+ "SamplerState"
+ };
+ write!(self.out, "{sampler}")?;
+ }
+ // HLSL arrays are written as `type name[size]`
+ // Current code is written arrays only as `[size]`
+ // Base `type` and `name` should be written outside
+ TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
+ self.write_array_size(module, base, size)?;
+ }
+ _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write functions
+ /// # Notes
+ /// Ends in a newline
+ fn write_function(
+ &mut self,
+ module: &Module,
+ name: &str,
+ func: &crate::Function,
+ func_ctx: &back::FunctionCtx<'_>,
+ info: &valid::FunctionInfo,
+ ) -> BackendResult {
+ // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
+
+ self.update_expressions_to_bake(module, func, info);
+
+ // Write modifier
+ if let Some(crate::FunctionResult {
+ binding:
+ Some(
+ ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position {
+ invariant: true,
+ }),
+ ),
+ ..
+ }) = func.result
+ {
+ self.write_modifier(binding)?;
+ }
+
+ // Write return type
+ if let Some(ref result) = func.result {
+ match func_ctx.ty {
+ back::FunctionType::Function(_) => {
+ self.write_type(module, result.ty)?;
+ }
+ back::FunctionType::EntryPoint(index) => {
+ if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
+ write!(self.out, "{}", ep_output.ty_name)?;
+ } else {
+ self.write_type(module, result.ty)?;
+ }
+ }
+ }
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ // Write function name
+ write!(self.out, " {name}(")?;
+
+ let need_workgroup_variables_initialization =
+ self.need_workgroup_variables_initialization(func_ctx, module);
+
+ // Write function arguments for non entry point functions
+ match func_ctx.ty {
+ back::FunctionType::Function(handle) => {
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // Write argument type
+ let arg_ty = match module.types[arg.ty].inner {
+ // pointers in function arguments are expected and resolve to `inout`
+ TypeInner::Pointer { base, .. } => {
+ //TODO: can we narrow this down to just `in` when possible?
+ write!(self.out, "inout ")?;
+ base
+ }
+ _ => arg.ty,
+ };
+ self.write_type(module, arg_ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::FunctionArgument(handle, index as u32)];
+
+ // Write argument name. Space is important.
+ write!(self.out, " {argument_name}")?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
+ write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
+ } else {
+ let stage = module.entry_points[ep_index as usize].stage;
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, arg.ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
+
+ write!(self.out, " {argument_name}")?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ if let Some(ref binding) = arg.binding {
+ self.write_semantic(binding, Some((stage, Io::Input)))?;
+ }
+ }
+
+ if need_workgroup_variables_initialization {
+ if !func.arguments.is_empty() {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
+ }
+ }
+ }
+ }
+ // Ends of arguments
+ write!(self.out, ")")?;
+
+ // Write semantic if it present
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ let stage = module.entry_points[index as usize].stage;
+ if let Some(crate::FunctionResult {
+ binding: Some(ref binding),
+ ..
+ }) = func.result
+ {
+ self.write_semantic(binding, Some((stage, Io::Output)))?;
+ }
+ }
+
+ // Function body start
+ writeln!(self.out)?;
+ writeln!(self.out, "{{")?;
+
+ if need_workgroup_variables_initialization {
+ self.write_workgroup_variables_initialization(func_ctx, module)?;
+ }
+
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ self.write_ep_arguments_initialization(module, func, index)?;
+ }
+
+ // Write function local variables
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability)
+ write!(self.out, "{}", back::INDENT)?;
+
+ // Write the local name
+ // The leading space is important
+ self.write_type(module, local.ty)?;
+ write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ write!(self.out, " = ")?;
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ self.write_expr(module, init, func_ctx)?;
+ } else {
+ // Zero initialize local variables
+ self.write_default_init(module, local.ty)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ if !func.local_variables.is_empty() {
+ writeln!(self.out)?;
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // The indentation should always be 1 when writing the function body
+ self.write_stmt(module, sta, func_ctx, back::Level(1))?;
+ }
+
+ writeln!(self.out, "}}")?;
+
+ self.named_expressions.clear();
+
+ Ok(())
+ }
+
+ fn need_workgroup_variables_initialization(
+ &mut self,
+ func_ctx: &back::FunctionCtx,
+ module: &Module,
+ ) -> bool {
+ self.options.zero_initialize_workgroup_memory
+ && func_ctx.ty.is_compute_entry_point(module)
+ && module.global_variables.iter().any(|(handle, var)| {
+ !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ }
+
+ fn write_workgroup_variables_initialization(
+ &mut self,
+ func_ctx: &back::FunctionCtx,
+ module: &Module,
+ ) -> BackendResult {
+ let level = back::Level(1);
+
+ writeln!(
+ self.out,
+ "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
+ )?;
+
+ let vars = module.global_variables.iter().filter(|&(handle, var)| {
+ !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ });
+
+ for (handle, var) in vars {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}{} = ", level.next(), name)?;
+ self.write_default_init(module, var.ty)?;
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)
+ }
+
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ module: &Module,
+ stmt: &crate::Statement,
+ func_ctx: &back::FunctionCtx<'_>,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::Statement;
+
+ match *stmt {
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ // HLSL can't save a pointer-valued expression in a variable,
+ // but we shouldn't ever need to: they should never be named expressions,
+ // and none of the expression types flagged by bake_ref_count can be pointer-valued.
+ None
+ } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else if self.need_bake_expressions.contains(&handle) {
+ Some(format!("_expr{}", handle.index()))
+ } else {
+ None
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.write_named_expr(module, handle, name, handle, func_ctx)?;
+ }
+ }
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if (")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ let l2 = level.next();
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Kill => writeln!(self.out, "{level}discard;")?,
+ Statement::Return { value: None } => {
+ writeln!(self.out, "{level}return;")?;
+ }
+ Statement::Return { value: Some(expr) } => {
+ let base_ty_res = &func_ctx.info[expr].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ if let TypeInner::Pointer { base, space: _ } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ if let TypeInner::Struct { .. } = *resolved {
+ // We can safely unwrap here, since we now we working with struct
+ let ty = base_ty_res.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(ty)];
+ let variable_name = self.namer.call(&struct_name.to_lowercase());
+ write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // for entry point returns, we may need to reshuffle the outputs into a different struct
+ let ep_output = match func_ctx.ty {
+ back::FunctionType::Function(_) => None,
+ back::FunctionType::EntryPoint(index) => {
+ self.entry_point_io[index as usize].output.as_ref()
+ }
+ };
+ let final_name = match ep_output {
+ Some(ep_output) => {
+ let final_name = self.namer.call(&variable_name);
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ level, ep_output.ty_name, final_name,
+ )?;
+ for (index, m) in ep_output.members.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
+ write!(self.out, "{variable_name}.{member_name}")?;
+ }
+ writeln!(self.out, " }};")?;
+ final_name
+ }
+ None => variable_name,
+ };
+ writeln!(self.out, "{level}return {final_name};")?;
+ } else {
+ write!(self.out, "{level}return ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ Statement::Store { pointer, value } => {
+ let ty_inner = func_ctx.resolve_type(pointer, &module.types);
+ if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ self.write_storage_store(
+ module,
+ var_handle,
+ StoreValue::Expression(value),
+ func_ctx,
+ level,
+ )?;
+ } else {
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
+ // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
+ struct MatrixAccess {
+ base: Handle<crate::Expression>,
+ index: u32,
+ }
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let get_members = |expr: Handle<crate::Expression>| {
+ let resolved = func_ctx.resolve_type(expr, &module.types);
+ match *resolved {
+ TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ TypeInner::Struct { ref members, .. } => Some(members),
+ _ => None,
+ },
+ _ => None,
+ }
+ };
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.resolve_type(current_expr, &module.types);
+
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::Pointer { base: ty, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) if matches!(
+ module.types[ty].inner,
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ }
+ ) && get_members(base)
+ .map(|members| members[index as usize].binding.is_none())
+ == Some(true) =>
+ {
+ matrix = Some(MatrixAccess { base, index });
+ break;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ vector = Some(Index::Static(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => break,
+ }
+ }
+
+ write!(self.out, "{level}")?;
+
+ if let Some(MatrixAccess { index, base }) = matrix {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+ let ty = match *resolved {
+ TypeInner::Pointer { base, .. } => base,
+ _ => base_ty_res.handle().unwrap(),
+ };
+
+ if let Some(Index::Static(vec_index)) = vector {
+ self.write_expr(module, base, func_ctx)?;
+ write!(
+ self.out,
+ ".{}_{}",
+ &self.names[&NameKey::StructMember(ty, index)],
+ vec_index
+ )?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, "[")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ write!(self.out, "]")?;
+ }
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ } else {
+ let access = WrappedStructMatrixAccess { ty, index };
+ match (&vector, &scalar) {
+ (&Some(_), &Some(_)) => {
+ self.write_wrapped_struct_matrix_set_scalar_function_name(
+ access,
+ )?;
+ }
+ (&Some(_), &None) => {
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+ }
+ (&None, _) => {
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+ }
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ if let Some(Index::Expression(vec_index)) = vector {
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
+ } else {
+ // We handle `Store`s to __matCx2 column vectors and scalar elements via
+ // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
+ struct MatrixData {
+ columns: crate::VectorSize,
+ base: Handle<crate::Expression>,
+ }
+
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.resolve_type(current_expr, &module.types);
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(index);
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => {
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module,
+ current_expr,
+ func_ctx,
+ true,
+ ) {
+ matrix = Some(MatrixData {
+ columns,
+ base: current_expr,
+ });
+ }
+
+ break;
+ }
+ }
+ }
+
+ if let (Some(MatrixData { columns, base }), Some(vec_index)) =
+ (matrix, vector)
+ {
+ if scalar.is_some() {
+ write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
+ } else {
+ write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
+ }
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ writeln!(self.out, ");")?;
+ } else {
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, " = ")?;
+
+ // We cast the RHS of this store in cases where the LHS
+ // is a struct member with type:
+ // - matCx2 or
+ // - a (possibly nested) array of matCx2's
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ ) {
+ let mut resolved = func_ctx.resolve_type(pointer, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "(__mat{}x2", columns as u8)?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ }
+ }
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let l2 = level.next();
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
+ let l3 = l2.next();
+ for sta in continuing.iter() {
+ self.write_stmt(module, sta, func_ctx, l3)?;
+ }
+ if let Some(condition) = break_if {
+ write!(self.out, "{l3}if (")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", l3.next())?;
+ writeln!(self.out, "{l3}}}")?;
+ }
+ writeln!(self.out, "{l2}}}")?;
+ writeln!(self.out, "{l2}{gate_name} = false;")?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+
+ for sta in body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Break => writeln!(self.out, "{level}break;")?,
+ Statement::Continue => writeln!(self.out, "{level}continue;")?,
+ Statement::Barrier(barrier) => {
+ self.write_barrier(barrier, level)?;
+ }
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ self.write_expr(module, image, func_ctx)?;
+
+ write!(self.out, "[")?;
+ if let Some(index) = array_index {
+ // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
+ write!(self.out, "int3(")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(module, coordinate, func_ctx)?;
+ }
+ write!(self.out, "]")?;
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ }
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ write!(self.out, "const ")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ let expr_ty = &func_ctx.info[expr].ty;
+ match *expr_ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+ write!(self.out, " {name} = ")?;
+ self.named_expressions.insert(expr, name);
+ }
+ let func_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{func_name}(")?;
+ for (index, argument) in arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_expr(module, *argument, func_ctx)?;
+ }
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ match func_ctx.info[result].ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+
+ // Validation ensures that `pointer` has a `Pointer` type.
+ let pointer_space = func_ctx
+ .resolve_type(pointer, &module.types)
+ .pointer_space()
+ .unwrap();
+
+ let fun_str = fun.to_hlsl_suffix();
+ write!(self.out, " {res_name}; ")?;
+ match pointer_space {
+ crate::AddressSpace::WorkGroup => {
+ write!(self.out, "Interlocked{fun_str}(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ }
+ crate::AddressSpace::Storage { .. } => {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ // The call to `self.write_storage_address` wants
+ // mutable access to all of `self`, so temporarily take
+ // ownership of our reusable access chain buffer.
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{var_name}.Interlocked{fun_str}(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ self.temp_access_chain = chain;
+ }
+ ref other => {
+ return Err(Error::Custom(format!(
+ "invalid address space {other:?} for atomic statement"
+ )))
+ }
+ }
+ write!(self.out, ", ")?;
+ // handle the special cases
+ match *fun {
+ crate::AtomicFunction::Subtract => {
+ // we just wrote `InterlockedAdd`, so negate the argument
+ write!(self.out, "-")?;
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
+ }
+ _ => {}
+ }
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ", {res_name});")?;
+ self.named_expressions.insert(result, res_name);
+ }
+ Statement::WorkGroupUniformLoad { pointer, result } => {
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ write!(self.out, "{level}")?;
+ let name = format!("_expr{}", result.index());
+ self.write_named_expr(module, pointer, name, result, func_ctx)?;
+
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch(")?;
+ self.write_expr(module, selector, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ // Write all cases
+ let indent_level_1 = level.next();
+ let indent_level_2 = indent_level_1.next();
+
+ for (i, case) in cases.iter().enumerate() {
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ write!(self.out, "{indent_level_1}case {value}:")?
+ }
+ crate::SwitchValue::U32(value) => {
+ write!(self.out, "{indent_level_1}case {value}u:")?
+ }
+ crate::SwitchValue::Default => {
+ write!(self.out, "{indent_level_1}default:")?
+ }
+ }
+
+ // The new block is not only stylistic, it plays a role here:
+ // We might end up having to write the same case body
+ // multiple times due to FXC not supporting fallthrough.
+ // Therefore, some `Expression`s written by `Statement::Emit`
+ // will end up having the same name (`_expr<handle_index>`).
+ // So we need to put each case in its own scope.
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ // Although FXC does support a series of case clauses before
+ // a block[^yes], it does not support fallthrough from a
+ // non-empty case block to the next[^no]. If this case has a
+ // non-empty body with a fallthrough, emulate that by
+ // duplicating the bodies of all the cases it would fall
+ // into as extensions of this case's own body. This makes
+ // the HLSL output potentially quadratic in the size of the
+ // Naga IR.
+ //
+ // [^yes]: ```hlsl
+ // case 1:
+ // case 2: do_stuff()
+ // ```
+ // [^no]: ```hlsl
+ // case 1: do_this();
+ // case 2: do_that();
+ // ```
+ if case.fall_through && !case.body.is_empty() {
+ let curr_len = i + 1;
+ let end_case_idx = curr_len
+ + cases
+ .iter()
+ .skip(curr_len)
+ .position(|case| !case.fall_through)
+ .unwrap();
+ let indent_level_3 = indent_level_2.next();
+ for case in &cases[i..=end_case_idx] {
+ writeln!(self.out, "{indent_level_2}{{")?;
+ let prev_len = self.named_expressions.len();
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_3)?;
+ }
+ // Clear all named expressions that were previously inserted by the statements in the block
+ self.named_expressions.truncate(prev_len);
+ writeln!(self.out, "{indent_level_2}}}")?;
+ }
+
+ let last_case = &cases[end_case_idx];
+ if last_case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{indent_level_2}break;")?;
+ }
+ } else {
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_2)?;
+ }
+ if !case.fall_through
+ && case.body.last().map_or(true, |s| !s.is_terminator())
+ {
+ writeln!(self.out, "{indent_level_2}break;")?;
+ }
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{indent_level_1}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ fn write_const_expression(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ ) -> BackendResult {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ &module.const_expressions,
+ |writer, expr| writer.write_const_expression(module, expr),
+ )
+ }
+
+ fn write_possibly_const_expression<E>(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ expressions: &crate::Arena<crate::Expression>,
+ write_expression: E,
+ ) -> BackendResult
+ where
+ E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
+ {
+ use crate::Expression;
+
+ match expressions[expr] {
+ Expression::Literal(literal) => match literal {
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
+ crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
+ crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
+ crate::Literal::I32(value) => write!(self.out, "{}", value)?,
+ crate::Literal::I64(value) => write!(self.out, "{}L", value)?,
+ crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ return Err(Error::Custom(
+ "Abstract types should not appear in IR presented to backends".into(),
+ ));
+ }
+ },
+ Expression::Constant(handle) => {
+ let constant = &module.constants[handle];
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_const_expression(module, constant.init)?;
+ }
+ }
+ Expression::ZeroValue(ty) => self.write_default_init(module, ty)?,
+ Expression::Compose { ty, ref components } => {
+ match module.types[ty].inner {
+ TypeInner::Struct { .. } | TypeInner::Array { .. } => {
+ self.write_wrapped_constructor_function_name(
+ module,
+ WrappedConstructor { ty },
+ )?;
+ }
+ _ => {
+ self.write_type(module, ty)?;
+ }
+ };
+ write!(self.out, "(")?;
+ for (index, component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write_expression(self, *component)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Expression::Splat { size, value } => {
+ // hlsl is not supported one value constructor
+ // if we write, for example, int4(0), dxc returns error:
+ // error: too few elements in vector initialization (expected 4 elements, have 1)
+ let number_of_components = match size {
+ crate::VectorSize::Bi => "xx",
+ crate::VectorSize::Tri => "xxx",
+ crate::VectorSize::Quad => "xxxx",
+ };
+ write!(self.out, "(")?;
+ write_expression(self, value)?;
+ write!(self.out, ").{number_of_components}")?
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write expressions
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ pub(super) fn write_expr(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ // Handle the special semantics of vertex_index/instance_index
+ let ff_input = if self.options.special_constants_binding.is_some() {
+ func_ctx.is_fixed_function_input(expr, module)
+ } else {
+ None
+ };
+ let closing_bracket = match ff_input {
+ Some(crate::BuiltIn::VertexIndex) => {
+ write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
+ ")"
+ }
+ Some(crate::BuiltIn::InstanceIndex) => {
+ write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
+ ")"
+ }
+ Some(crate::BuiltIn::NumWorkGroups) => {
+ // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
+ // in compute shaders the special constants contain the number
+ // of workgroups, which we are using here.
+ write!(
+ self.out,
+ "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
+ )?;
+ return Ok(());
+ }
+ _ => "",
+ };
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}{closing_bracket}")?;
+ return Ok(());
+ }
+
+ let expression = &func_ctx.expressions[expr];
+
+ match *expression {
+ Expression::Literal(_)
+ | Expression::Constant(_)
+ | Expression::ZeroValue(_)
+ | Expression::Compose { .. }
+ | Expression::Splat { .. } => {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ func_ctx.expressions,
+ |writer, expr| writer.write_expr(module, expr, func_ctx),
+ )?;
+ }
+ // All of the multiplication can be expressed as `mul`,
+ // except vector * vector, which needs to use the "*" operator.
+ Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left,
+ right,
+ } if func_ctx.resolve_type(left, &module.types).is_matrix()
+ || func_ctx.resolve_type(right, &module.types).is_matrix() =>
+ {
+ // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
+ write!(self.out, "mul(")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) != sign(right) return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ // While HLSL supports float operands with the % operator it is only
+ // defined in cases where both sides are either positive or negative.
+ Expression::Binary {
+ op: crate::BinaryOperator::Modulo,
+ left,
+ right,
+ } if func_ctx.resolve_type(left, &module.types).scalar_kind()
+ == Some(crate::ScalarKind::Float) =>
+ {
+ write!(self.out, "fmod(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, " {} ", crate::back::binary_operation_str(op))?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Access { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) =
+ func_ctx.resolve_type(expr, &module.types).pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ // We use the function __get_col_of_matCx2 here in cases
+ // where `base`s type resolves to a matCx2 and is part of a
+ // struct member with type of (possibly nested) array of matCx2's.
+ //
+ // Note that this only works for `Load`s and we handle
+ // `Store`s differently in `Statement::Store`.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+
+ let resolved = func_ctx.resolve_type(base, &module.types);
+
+ let non_uniform_qualifier = match *resolved {
+ TypeInner::BindingArray { .. } => {
+ let uniformity = &func_ctx.info[index].uniformity;
+
+ uniformity.non_uniform_result.is_some()
+ }
+ _ => false,
+ };
+
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "[")?;
+ if non_uniform_qualifier {
+ write!(self.out, "NonUniformResourceIndex(")?;
+ }
+ self.write_expr(module, index, func_ctx)?;
+ if non_uniform_qualifier {
+ write!(self.out, ")")?;
+ }
+ write!(self.out, "]")?;
+ }
+ }
+ Expression::AccessIndex { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) =
+ func_ctx.resolve_type(expr, &module.types).pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ fn write_access<W: fmt::Write>(
+ writer: &mut super::Writer<'_, W>,
+ resolved: &TypeInner,
+ base_ty_handle: Option<Handle<crate::Type>>,
+ index: u32,
+ ) -> BackendResult {
+ match *resolved {
+ // We specifically lift the ValuePointer to this case. While `[0]` is valid
+ // HLSL for any vector behind a value pointer, FXC completely miscompiles
+ // it and generates completely nonsensical DXBC.
+ //
+ // See https://github.com/gfx-rs/naga/issues/2095 for more details.
+ TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
+ // Write vector access as a swizzle
+ write!(writer.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::BindingArray { .. } => write!(writer.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ writer.out,
+ ".{}",
+ &writer.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => {
+ return Err(Error::Custom(format!("Cannot index {other:?}")))
+ }
+ }
+ Ok(())
+ }
+
+ // We write the matrix column access in a special way since
+ // the type of `base` is our special __matCx2 struct.
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "._{index}")?;
+ return Ok(());
+ }
+
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix reconstruction here for Loads.
+ // Stores are handled directly by `Statement::Store`.
+ if let TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ self.write_wrapped_struct_matrix_get_function_name(
+ WrappedStructMatrixAccess { ty, index },
+ )?;
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+ _ => {}
+ }
+ }
+
+ self.write_expr(module, base, func_ctx)?;
+ write_access(self, resolved, base_ty_handle, index)?;
+ }
+ }
+ Expression::FunctionArgument(pos) => {
+ let key = func_ctx.argument_key(pos);
+ let name = &self.names[&key];
+ write!(self.out, "{name}")?;
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ use crate::SampleLevel as Sl;
+ const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
+
+ let (base_str, component_str) = match gather {
+ Some(component) => ("Gather", COMPONENTS[component as usize]),
+ None => ("Sample", ""),
+ };
+ let cmp_str = match depth_ref {
+ Some(_) => "Cmp",
+ None => "",
+ };
+ let level_str = match level {
+ Sl::Zero if gather.is_none() => "LevelZero",
+ Sl::Auto | Sl::Zero => "",
+ Sl::Exact(_) => "Level",
+ Sl::Bias(_) => "Bias",
+ Sl::Gradient { .. } => "Grad",
+ };
+
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_texture_coordinates(
+ "float",
+ coordinate,
+ array_index,
+ None,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ match level {
+ Sl::Auto | Sl::Zero => {}
+ Sl::Exact(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Gradient { x, y } => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, x, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, y, func_ctx)?;
+ }
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
+ self.write_const_expression(module, offset)?;
+ write!(self.out, ")")?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageQuery { image, query } => {
+ // use wrapped image query function
+ if let TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } = *func_ctx.resolve_type(image, &module.types)
+ {
+ let wrapped_image_query = WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ };
+
+ self.write_wrapped_image_query_function_name(wrapped_image_query)?;
+ write!(self.out, "(")?;
+ // Image always first param
+ self.write_expr(module, image, func_ctx)?;
+ if let crate::ImageQuery::Size { level: Some(level) } = query {
+ write!(self.out, ", ")?;
+ self.write_expr(module, level, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ".Load(")?;
+
+ self.write_texture_coordinates(
+ "int",
+ coordinate,
+ array_index,
+ level,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(sample) = sample {
+ write!(self.out, ", ")?;
+ self.write_expr(module, sample, func_ctx)?;
+ }
+
+ // close bracket for Load function
+ write!(self.out, ")")?;
+
+ // return x component if return type is scalar
+ if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
+ write!(self.out, ".x")?;
+ }
+ }
+ Expression::GlobalVariable(handle) => match module.global_variables[handle].space {
+ crate::AddressSpace::Storage { .. } => {}
+ _ => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+ },
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
+ }
+ Expression::Load { pointer } => {
+ match func_ctx
+ .resolve_type(pointer, &module.types)
+ .pointer_space()
+ {
+ Some(crate::AddressSpace::Storage { .. }) => {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ let result_ty = func_ctx.info[expr].ty.clone();
+ self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
+ }
+ _ => {
+ let mut close_paren = false;
+
+ // We cast the value loaded to a native HLSL floatCx2
+ // in cases where it is of type:
+ // - __matCx2 or
+ // - a (possibly nested) array of __matCx2's
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ )
+ .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
+ {
+ let mut resolved = func_ctx.resolve_type(pointer, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "((")?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_type(module, base)?;
+ self.write_array_size(module, base, size)?;
+ } else {
+ self.write_value_type(module, resolved)?;
+ }
+ write!(self.out, ")")?;
+ close_paren = true;
+ }
+
+ self.write_expr(module, pointer, func_ctx)?;
+
+ if close_paren {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ }
+ Expression::Unary { op, expr } => {
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
+ let op_str = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::LogicalNot => "!",
+ crate::UnaryOperator::BitwiseNot => "~",
+ };
+ write!(self.out, "{op_str}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let inner = func_ctx.resolve_type(expr, &module.types);
+ match convert {
+ Some(dst_width) => {
+ let scalar = crate::Scalar {
+ kind,
+ width: dst_width,
+ };
+ match *inner {
+ TypeInner::Vector { size, .. } => {
+ write!(
+ self.out,
+ "{}{}(",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Scalar(_) => {
+ write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
+ }
+ TypeInner::Matrix { columns, rows, .. } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::as {inner:?}"
+ )));
+ }
+ };
+ }
+ None => {
+ write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
+ }
+ }
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ enum Function {
+ Asincosh { is_sin: bool },
+ Atanh,
+ ExtractBits,
+ InsertBits,
+ Pack2x16float,
+ Pack2x16snorm,
+ Pack2x16unorm,
+ Pack4x8snorm,
+ Pack4x8unorm,
+ Unpack2x16float,
+ Unpack2x16snorm,
+ Unpack2x16unorm,
+ Unpack4x8snorm,
+ Unpack4x8unorm,
+ Regular(&'static str),
+ MissingIntOverload(&'static str),
+ MissingIntReturnType(&'static str),
+ CountTrailingZeros,
+ CountLeadingZeros,
+ }
+
+ let fun = match fun {
+ // comparison
+ Mf::Abs => Function::Regular("abs"),
+ Mf::Min => Function::Regular("min"),
+ Mf::Max => Function::Regular("max"),
+ Mf::Clamp => Function::Regular("clamp"),
+ Mf::Saturate => Function::Regular("saturate"),
+ // trigonometry
+ Mf::Cos => Function::Regular("cos"),
+ Mf::Cosh => Function::Regular("cosh"),
+ Mf::Sin => Function::Regular("sin"),
+ Mf::Sinh => Function::Regular("sinh"),
+ Mf::Tan => Function::Regular("tan"),
+ Mf::Tanh => Function::Regular("tanh"),
+ Mf::Acos => Function::Regular("acos"),
+ Mf::Asin => Function::Regular("asin"),
+ Mf::Atan => Function::Regular("atan"),
+ Mf::Atan2 => Function::Regular("atan2"),
+ Mf::Asinh => Function::Asincosh { is_sin: true },
+ Mf::Acosh => Function::Asincosh { is_sin: false },
+ Mf::Atanh => Function::Atanh,
+ Mf::Radians => Function::Regular("radians"),
+ Mf::Degrees => Function::Regular("degrees"),
+ // decomposition
+ Mf::Ceil => Function::Regular("ceil"),
+ Mf::Floor => Function::Regular("floor"),
+ Mf::Round => Function::Regular("round"),
+ Mf::Fract => Function::Regular("frac"),
+ Mf::Trunc => Function::Regular("trunc"),
+ Mf::Modf => Function::Regular(MODF_FUNCTION),
+ Mf::Frexp => Function::Regular(FREXP_FUNCTION),
+ Mf::Ldexp => Function::Regular("ldexp"),
+ // exponent
+ Mf::Exp => Function::Regular("exp"),
+ Mf::Exp2 => Function::Regular("exp2"),
+ Mf::Log => Function::Regular("log"),
+ Mf::Log2 => Function::Regular("log2"),
+ Mf::Pow => Function::Regular("pow"),
+ // geometry
+ Mf::Dot => Function::Regular("dot"),
+ //Mf::Outer => ,
+ Mf::Cross => Function::Regular("cross"),
+ Mf::Distance => Function::Regular("distance"),
+ Mf::Length => Function::Regular("length"),
+ Mf::Normalize => Function::Regular("normalize"),
+ Mf::FaceForward => Function::Regular("faceforward"),
+ Mf::Reflect => Function::Regular("reflect"),
+ Mf::Refract => Function::Regular("refract"),
+ // computational
+ Mf::Sign => Function::Regular("sign"),
+ Mf::Fma => Function::Regular("mad"),
+ Mf::Mix => Function::Regular("lerp"),
+ Mf::Step => Function::Regular("step"),
+ Mf::SmoothStep => Function::Regular("smoothstep"),
+ Mf::Sqrt => Function::Regular("sqrt"),
+ Mf::InverseSqrt => Function::Regular("rsqrt"),
+ //Mf::Inverse =>,
+ Mf::Transpose => Function::Regular("transpose"),
+ Mf::Determinant => Function::Regular("determinant"),
+ // bits
+ Mf::CountTrailingZeros => Function::CountTrailingZeros,
+ Mf::CountLeadingZeros => Function::CountLeadingZeros,
+ Mf::CountOneBits => Function::MissingIntOverload("countbits"),
+ Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
+ Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"),
+ Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"),
+ Mf::ExtractBits => Function::ExtractBits,
+ Mf::InsertBits => Function::InsertBits,
+ // Data Packing
+ Mf::Pack2x16float => Function::Pack2x16float,
+ Mf::Pack2x16snorm => Function::Pack2x16snorm,
+ Mf::Pack2x16unorm => Function::Pack2x16unorm,
+ Mf::Pack4x8snorm => Function::Pack4x8snorm,
+ Mf::Pack4x8unorm => Function::Pack4x8unorm,
+ // Data Unpacking
+ Mf::Unpack2x16float => Function::Unpack2x16float,
+ Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
+ Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
+ Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
+ Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
+ _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
+ };
+
+ match fun {
+ Function::Asincosh { is_sin } => {
+ write!(self.out, "log(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " + sqrt(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " * ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ match is_sin {
+ true => write!(self.out, " + 1.0))")?,
+ false => write!(self.out, " - 1.0))")?,
+ }
+ }
+ Function::Atanh => {
+ write!(self.out, "0.5 * log((1.0 + ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") / (1.0 - ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ }
+ Function::ExtractBits => {
+ // e: T,
+ // offset: u32,
+ // count: u32
+ // T is u32 or i32 or vecN<u32> or vecN<i32>
+ if let (Some(offset), Some(count)) = (arg1, arg2) {
+ let scalar_width: u8 = 32;
+ // Works for signed and unsigned
+ // (count == 0 ? 0 : (e << (32 - count - offset)) >> (32 - count))
+ write!(self.out, "(")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, " == 0 ? 0 : (")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << ({scalar_width} - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, " - ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ")) >> ({scalar_width} - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, "))")?;
+ }
+ }
+ Function::InsertBits => {
+ // e: T,
+ // newbits: T,
+ // offset: u32,
+ // count: u32
+ // returns T
+ // T is i32, u32, vecN<i32>, or vecN<u32>
+ if let (Some(newbits), Some(offset), Some(count)) = (arg1, arg2, arg3) {
+ let scalar_width: u8 = 32;
+ let scalar_max: u32 = 0xFFFFFFFF;
+ // mask = ((0xFFFFFFFFu >> (32 - count)) << offset)
+ // (count == 0 ? e : ((e & ~mask) | ((newbits << offset) & mask)))
+ write!(self.out, "(")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, " == 0 ? ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " : ")?;
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " & ~")?;
+ // mask
+ write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, ")) << ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ")")?;
+ // end mask
+ write!(self.out, ") | ((")?;
+ self.write_expr(module, newbits, func_ctx)?;
+ write!(self.out, " << ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ") & ")?;
+ // // mask
+ write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, ")) << ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ")")?;
+ // // end mask
+ write!(self.out, "))")?;
+ }
+ }
+ Function::Pack2x16float => {
+ write!(self.out, "(f32tof16(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[0]) | f32tof16(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[1]) << 16)")?;
+ }
+ Function::Pack2x16snorm => {
+ let scale = 32767;
+
+ write!(self.out, "uint((int(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
+ }
+ Function::Pack2x16unorm => {
+ let scale = 65535;
+
+ write!(self.out, "(uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
+ }
+ Function::Pack4x8snorm => {
+ let scale = 127;
+
+ write!(self.out, "uint((int(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
+ }
+ Function::Pack4x8unorm => {
+ let scale = 255;
+
+ write!(self.out, "(uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
+ }
+
+ Function::Unpack2x16float => {
+ write!(self.out, "float2(f16tof32(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "), f16tof32((")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 16))")?;
+ }
+ Function::Unpack2x16snorm => {
+ let scale = 32767;
+
+ write!(self.out, "(float2(int2(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 16, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 16) / {scale}.0)")?;
+ }
+ Function::Unpack2x16unorm => {
+ let scale = 65535;
+
+ write!(self.out, "(float2(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " & 0xFFFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 16) / {scale}.0)")?;
+ }
+ Function::Unpack4x8snorm => {
+ let scale = 127;
+
+ write!(self.out, "(float4(int4(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 24, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 16, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 8, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 24) / {scale}.0)")?;
+ }
+ Function::Unpack4x8unorm => {
+ let scale = 255;
+
+ write!(self.out, "(float4(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " & 0xFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 8 & 0xFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 16 & 0xFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 24) / {scale}.0)")?;
+ }
+ Function::Regular(fun_name) => {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ if let Some(arg) = arg1 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg2 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg3 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ write!(self.out, ")")?
+ }
+ Function::MissingIntOverload(fun_name) => {
+ let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
+ if let Some(ScalarKind::Sint) = scalar_kind {
+ write!(self.out, "asint({fun_name}(asuint(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ } else {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Function::MissingIntReturnType(fun_name) => {
+ let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
+ if let Some(ScalarKind::Sint) = scalar_kind {
+ write!(self.out, "asint({fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Function::CountTrailingZeros => {
+ match *func_ctx.resolve_type(arg, &module.types) {
+ TypeInner::Vector { size, scalar } => {
+ let s = match size {
+ crate::VectorSize::Bi => ".xx",
+ crate::VectorSize::Tri => ".xxx",
+ crate::VectorSize::Quad => ".xxxx",
+ };
+
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min((32u){s}, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "asint(min((32u){s}, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ TypeInner::Scalar(scalar) => {
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min(32u, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "asint(min(32u, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ return Ok(());
+ }
+ Function::CountLeadingZeros => {
+ match *func_ctx.resolve_type(arg, &module.types) {
+ TypeInner::Vector { size, scalar } => {
+ let s = match size {
+ crate::VectorSize::Bi => ".xx",
+ crate::VectorSize::Tri => ".xxx",
+ crate::VectorSize::Quad => ".xxxx",
+ };
+
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "((31u){s} - firstbithigh(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ " < (0){s} ? (0){s} : (31){s} - asint(firstbithigh("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ TypeInner::Scalar(scalar) => {
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "(31u - firstbithigh(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ return Ok(());
+ }
+ }
+ }
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(module, vector, func_ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ Expression::ArrayLength(expr) => {
+ let var_handle = match func_ctx.expressions[expr] {
+ Expression::AccessIndex { base, index: _ } => {
+ match func_ctx.expressions[base] {
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ }
+ }
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ };
+
+ let var = &module.global_variables[var_handle];
+ let (offset, stride) = match module.types[var.ty].inner {
+ TypeInner::Array { stride, .. } => (0, stride),
+ TypeInner::Struct { ref members, .. } => {
+ let last = members.last().unwrap();
+ let stride = match module.types[last.ty].inner {
+ TypeInner::Array { stride, .. } => stride,
+ _ => unreachable!(),
+ };
+ (last.offset, stride)
+ }
+ _ => unreachable!(),
+ };
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wrapped_array_length = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ write!(self.out, "((")?;
+ self.write_wrapped_array_length_function_name(wrapped_array_length)?;
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "({var_name}) - {offset}) / {stride})")?
+ }
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
+ let tail = match ctrl {
+ Ctrl::Coarse => "coarse",
+ Ctrl::Fine => "fine",
+ Ctrl::None => unreachable!(),
+ };
+ write!(self.out, "abs(ddx_{tail}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")) + abs(ddy_{tail}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, "))")?
+ } else {
+ let fun_str = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "ddx_coarse",
+ (Axis::X, Ctrl::Fine) => "ddx_fine",
+ (Axis::X, Ctrl::None) => "ddx",
+ (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
+ (Axis::Y, Ctrl::Fine) => "ddy_fine",
+ (Axis::Y, Ctrl::None) => "ddy",
+ (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
+ (Axis::Width, Ctrl::None) => "fwidth",
+ };
+ write!(self.out, "{fun_str}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ }
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_str = match fun {
+ Rf::All => "all",
+ Rf::Any => "any",
+ Rf::IsNan => "isnan",
+ Rf::IsInf => "isinf",
+ };
+ write!(self.out, "{fun_str}(")?;
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, condition, func_ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(module, accept, func_ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(module, reject, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ // Not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+ // Nothing to do here, since call expression already cached
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::WorkGroupUniformLoadResult { .. }
+ | Expression::RayQueryProceedResult => {}
+ }
+
+ if !closing_bracket.is_empty() {
+ write!(self.out, "{closing_bracket}")?;
+ }
+ Ok(())
+ }
+
+ fn write_named_expr(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Expression>,
+ name: String,
+ // The expression which is being named.
+ // Generally, this is the same as handle, except in WorkGroupUniformLoad
+ named: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ match ctx.info[named].ty {
+ proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
+ TypeInner::Struct { .. } => {
+ let ty_name = &self.names[&NameKey::Type(ty_handle)];
+ write!(self.out, "{ty_name}")?;
+ }
+ _ => {
+ self.write_type(module, ty_handle)?;
+ }
+ },
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(module, inner)?;
+ }
+ }
+
+ let resolved = ctx.resolve_type(named, &module.types);
+
+ write!(self.out, " {name}")?;
+ // If rhs is a array type, we should write array size
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_expr(module, handle, ctx)?;
+ writeln!(self.out, ";")?;
+ self.named_expressions.insert(named, name);
+
+ Ok(())
+ }
+
+ /// Helper function that write default zero initialization
+ fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ write!(self.out, "(")?;
+ self.write_type(module, ty)?;
+ if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, ")0")?;
+ Ok(())
+ }
+
+ fn write_barrier(&mut self, barrier: crate::Barrier, level: back::Level) -> BackendResult {
+ if barrier.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
+ }
+ if barrier.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
+ }
+ Ok(())
+ }
+}
+
+pub(super) struct MatrixType {
+ pub(super) columns: crate::VectorSize,
+ pub(super) rows: crate::VectorSize,
+ pub(super) width: crate::Bytes,
+}
+
+pub(super) fn get_inner_matrix_data(
+ module: &Module,
+ handle: Handle<crate::Type>,
+) -> Option<MatrixType> {
+ match module.types[handle].inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => Some(MatrixType {
+ columns,
+ rows,
+ width: scalar.width,
+ }),
+ TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
+ _ => None,
+ }
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
+/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends at an expression with resolved type of [`TypeInner::Struct`]
+pub(super) fn get_inner_matrix_of_struct_array_member(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ direct: bool,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.resolve_type(current_base, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width: scalar.width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ TypeInner::Struct { .. } => {
+ if let Some(array_base) = array_base {
+ if direct {
+ return mat_data;
+ } else {
+ return get_inner_matrix_data(module, array_base);
+ }
+ }
+
+ break;
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ _ => break,
+ };
+ }
+ None
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
+/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
+fn get_inner_matrix_of_global_uniform(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.resolve_type(current_base, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width: scalar.width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ crate::Expression::GlobalVariable(handle)
+ if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
+ {
+ return mat_data.or_else(|| {
+ array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
+ })
+ }
+ _ => break,
+ };
+ }
+ None
+}