diff options
Diffstat (limited to 'third_party/rust/naga/src/back/wgsl')
-rw-r--r-- | third_party/rust/naga/src/back/wgsl/mod.rs | 52 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/wgsl/writer.rs | 1968 |
2 files changed, 2020 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/back/wgsl/mod.rs b/third_party/rust/naga/src/back/wgsl/mod.rs new file mode 100644 index 0000000000..d731b1ca0c --- /dev/null +++ b/third_party/rust/naga/src/back/wgsl/mod.rs @@ -0,0 +1,52 @@ +/*! +Backend for [WGSL][wgsl] (WebGPU Shading Language). + +[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html +*/ + +mod writer; + +use thiserror::Error; + +pub use writer::{Writer, WriterFlags}; + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + FmtError(#[from] std::fmt::Error), + #[error("{0}")] + Custom(String), + #[error("{0}")] + Unimplemented(String), // TODO: Error used only during development + #[error("Unsupported math function: {0:?}")] + UnsupportedMathFunction(crate::MathFunction), + #[error("Unsupported relational function: {0:?}")] + UnsupportedRelationalFunction(crate::RelationalFunction), +} + +pub fn write_string( + module: &crate::Module, + info: &crate::valid::ModuleInfo, + flags: WriterFlags, +) -> Result<String, Error> { + let mut w = Writer::new(String::new(), flags); + w.write(module, info)?; + let output = w.finish(); + Ok(output) +} + +impl crate::AtomicFunction { + const fn to_wgsl(self) -> &'static str { + match self { + Self::Add => "Add", + Self::Subtract => "Sub", + Self::And => "And", + Self::InclusiveOr => "Or", + Self::ExclusiveOr => "Xor", + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { .. } => "CompareExchangeWeak", + } + } +} diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs new file mode 100644 index 0000000000..92086c94a8 --- /dev/null +++ b/third_party/rust/naga/src/back/wgsl/writer.rs @@ -0,0 +1,1968 @@ +use super::Error; +use crate::{ + back, + proc::{self, NameKey}, + valid, Handle, Module, ShaderStage, TypeInner, +}; +use std::fmt::Write; + +/// Shorthand result used internally by the backend +type BackendResult = Result<(), Error>; + +/// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes) +enum Attribute { + Binding(u32), + BuiltIn(crate::BuiltIn), + Group(u32), + Invariant, + Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>), + Location(u32), + Stage(ShaderStage), + WorkGroupSize([u32; 3]), +} + +/// The WGSL form that `write_expr_with_indirection` should use to render a Naga +/// expression. +/// +/// Sometimes a Naga `Expression` alone doesn't provide enough information to +/// choose the right rendering for it in WGSL. For example, one natural WGSL +/// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since +/// `LocalVariable` produces a pointer to the local variable's storage. But when +/// rendering a `Store` statement, the `pointer` operand must be the left hand +/// side of a WGSL assignment, so the proper rendering is `x`. +/// +/// The caller of `write_expr_with_indirection` must provide an `Expected` value +/// to indicate how ambiguous expressions should be rendered. +#[derive(Clone, Copy, Debug)] +enum Indirection { + /// Render pointer-construction expressions as WGSL `ptr`-typed expressions. + /// + /// This is the right choice for most cases. Whenever a Naga pointer + /// expression is not the `pointer` operand of a `Load` or `Store`, it + /// must be a WGSL pointer expression. + Ordinary, + + /// Render pointer-construction expressions as WGSL reference-typed + /// expressions. + /// + /// For example, this is the right choice for the `pointer` operand when + /// rendering a `Store` statement as a WGSL assignment. + Reference, +} + +bitflags::bitflags! { + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + pub struct WriterFlags: u32 { + /// Always annotate the type information instead of inferring. + const EXPLICIT_TYPES = 0x1; + } +} + +pub struct Writer<W> { + out: W, + flags: WriterFlags, + names: crate::FastHashMap<NameKey, String>, + namer: proc::Namer, + named_expressions: crate::NamedExpressions, + ep_results: Vec<(ShaderStage, Handle<crate::Type>)>, +} + +impl<W: Write> Writer<W> { + pub fn new(out: W, flags: WriterFlags) -> Self { + Writer { + out, + flags, + names: crate::FastHashMap::default(), + namer: proc::Namer::default(), + named_expressions: crate::NamedExpressions::default(), + ep_results: vec![], + } + } + + fn reset(&mut self, module: &Module) { + self.names.clear(); + self.namer.reset( + module, + crate::keywords::wgsl::RESERVED, + // an identifier must not start with two underscore + &["__"], + &mut self.names, + ); + self.named_expressions.clear(); + self.ep_results.clear(); + } + + pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult { + self.reset(module); + + // Save all ep result types + for (_, ep) in module.entry_points.iter().enumerate() { + if let Some(ref result) = ep.function.result { + self.ep_results.push((ep.stage, result.ty)); + } + } + + // Write all structs + for (handle, ty) in module.types.iter() { + if let TypeInner::Struct { + ref members, + span: _, + } = ty.inner + { + self.write_struct(module, handle, members)?; + writeln!(self.out)?; + } + } + + // Write all constants + for (handle, constant) in module.constants.iter() { + if constant.name.is_some() { + self.write_global_constant(module, &constant.inner, handle)?; + } + } + + // Write all globals + for (ty, global) in module.global_variables.iter() { + self.write_global(module, global, ty)?; + } + + if !module.global_variables.is_empty() { + // Add extra newline for readability + writeln!(self.out)?; + } + + // Write all regular functions + for (handle, function) in module.functions.iter() { + let fun_info = &info[handle]; + + let func_ctx = back::FunctionCtx { + ty: back::FunctionType::Function(handle), + info: fun_info, + expressions: &function.expressions, + named_expressions: &function.named_expressions, + }; + + // Write the function + self.write_function(module, function, &func_ctx)?; + + writeln!(self.out)?; + } + + // Write all entry points + for (index, ep) in module.entry_points.iter().enumerate() { + let attributes = match ep.stage { + ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)], + ShaderStage::Compute => vec![ + Attribute::Stage(ShaderStage::Compute), + Attribute::WorkGroupSize(ep.workgroup_size), + ], + }; + + self.write_attributes(&attributes)?; + // Add a newline after attribute + writeln!(self.out)?; + + let func_ctx = back::FunctionCtx { + ty: back::FunctionType::EntryPoint(index as u16), + info: info.get_entry_point(index), + expressions: &ep.function.expressions, + named_expressions: &ep.function.named_expressions, + }; + self.write_function(module, &ep.function, &func_ctx)?; + + if index < module.entry_points.len() - 1 { + writeln!(self.out)?; + } + } + + Ok(()) + } + + /// Helper method used to write [`ScalarValue`](crate::ScalarValue) + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_scalar_value(&mut self, value: crate::ScalarValue) -> BackendResult { + use crate::ScalarValue as Sv; + + match value { + Sv::Sint(value) => write!(self.out, "{value}")?, + Sv::Uint(value) => write!(self.out, "{value}u")?, + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + Sv::Float(value) => write!(self.out, "{value:?}")?, + Sv::Bool(value) => write!(self.out, "{value}")?, + } + + Ok(()) + } + + /// Helper method used to write struct name + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_struct_name(&mut self, module: &Module, handle: Handle<crate::Type>) -> BackendResult { + if module.types[handle].name.is_none() { + if let Some(&(stage, _)) = self.ep_results.iter().find(|&&(_, ty)| ty == handle) { + let name = match stage { + ShaderStage::Compute => "ComputeOutput", + ShaderStage::Fragment => "FragmentOutput", + ShaderStage::Vertex => "VertexOutput", + }; + + write!(self.out, "{name}")?; + return Ok(()); + } + } + + write!(self.out, "{}", self.names[&NameKey::Type(handle)])?; + + Ok(()) + } + + /// Helper method used to write + /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions) + /// + /// # Notes + /// Ends in a newline + fn write_function( + &mut self, + module: &Module, + func: &crate::Function, + func_ctx: &back::FunctionCtx<'_>, + ) -> BackendResult { + let func_name = match func_ctx.ty { + back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)], + back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)], + }; + + // Write function name + write!(self.out, "fn {func_name}(")?; + + // Write function arguments + for (index, arg) in func.arguments.iter().enumerate() { + // Write argument attribute if a binding is present + if let Some(ref binding) = arg.binding { + self.write_attributes(&map_binding_to_attribute( + binding, + module.types[arg.ty].inner.scalar_kind(), + ))?; + } + // Write argument name + let argument_name = match func_ctx.ty { + back::FunctionType::Function(handle) => { + &self.names[&NameKey::FunctionArgument(handle, index as u32)] + } + back::FunctionType::EntryPoint(ep_index) => { + &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)] + } + }; + + write!(self.out, "{argument_name}: ")?; + // Write argument type + self.write_type(module, arg.ty)?; + if index < func.arguments.len() - 1 { + // Add a separator between args + write!(self.out, ", ")?; + } + } + + write!(self.out, ")")?; + + // Write function return type + if let Some(ref result) = func.result { + write!(self.out, " -> ")?; + if let Some(ref binding) = result.binding { + self.write_attributes(&map_binding_to_attribute( + binding, + module.types[result.ty].inner.scalar_kind(), + ))?; + } + self.write_type(module, result.ty)?; + } + + write!(self.out, " {{")?; + writeln!(self.out)?; + + // Write function local variables + for (handle, local) in func.local_variables.iter() { + // Write indentation (only for readability) + write!(self.out, "{}", back::INDENT)?; + + // Write the local name + // The leading space is important + write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?; + + // Write the local type + self.write_type(module, local.ty)?; + + // Write the local initializer if needed + if let Some(init) = local.init { + // Put the equal signal only if there's a initializer + // The leading and trailing spaces aren't needed but help with readability + write!(self.out, " = ")?; + + // Write the constant + // `write_constant` adds no trailing or leading space/newline + self.write_constant(module, init)?; + } + + // Finish the local with `;` and add a newline (only for readability) + writeln!(self.out, ";")? + } + + if !func.local_variables.is_empty() { + writeln!(self.out)?; + } + + // Write the function body (statement list) + for sta in func.body.iter() { + // The indentation should always be 1 when writing the function body + self.write_stmt(module, sta, func_ctx, back::Level(1))?; + } + + writeln!(self.out, "}}")?; + + self.named_expressions.clear(); + + Ok(()) + } + + /// Helper method to write a attribute + fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult { + for attribute in attributes { + match *attribute { + Attribute::Location(id) => write!(self.out, "@location({id}) ")?, + Attribute::BuiltIn(builtin_attrib) => { + let builtin = builtin_str(builtin_attrib)?; + write!(self.out, "@builtin({builtin}) ")?; + } + Attribute::Stage(shader_stage) => { + let stage_str = match shader_stage { + ShaderStage::Vertex => "vertex", + ShaderStage::Fragment => "fragment", + ShaderStage::Compute => "compute", + }; + write!(self.out, "@{stage_str} ")?; + } + Attribute::WorkGroupSize(size) => { + write!( + self.out, + "@workgroup_size({}, {}, {}) ", + size[0], size[1], size[2] + )?; + } + Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?, + Attribute::Group(id) => write!(self.out, "@group({id}) ")?, + Attribute::Invariant => write!(self.out, "@invariant ")?, + Attribute::Interpolate(interpolation, sampling) => { + if sampling.is_some() && sampling != Some(crate::Sampling::Center) { + write!( + self.out, + "@interpolate({}, {}) ", + interpolation_str( + interpolation.unwrap_or(crate::Interpolation::Perspective) + ), + sampling_str(sampling.unwrap_or(crate::Sampling::Center)) + )?; + } else if interpolation.is_some() + && interpolation != Some(crate::Interpolation::Perspective) + { + write!( + self.out, + "@interpolate({}) ", + interpolation_str( + interpolation.unwrap_or(crate::Interpolation::Perspective) + ) + )?; + } + } + }; + } + Ok(()) + } + + /// Helper method used to write structs + /// + /// # Notes + /// Ends in a newline + fn write_struct( + &mut self, + module: &Module, + handle: Handle<crate::Type>, + members: &[crate::StructMember], + ) -> BackendResult { + write!(self.out, "struct ")?; + self.write_struct_name(module, handle)?; + write!(self.out, " {{")?; + writeln!(self.out)?; + for (index, member) in members.iter().enumerate() { + // The indentation is only for readability + write!(self.out, "{}", back::INDENT)?; + if let Some(ref binding) = member.binding { + self.write_attributes(&map_binding_to_attribute( + binding, + module.types[member.ty].inner.scalar_kind(), + ))?; + } + // Write struct member name and type + let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; + write!(self.out, "{member_name}: ")?; + self.write_type(module, member.ty)?; + write!(self.out, ",")?; + writeln!(self.out)?; + } + + write!(self.out, "}}")?; + + writeln!(self.out)?; + + Ok(()) + } + + /// Helper method used to write non image/sampler types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult { + let inner = &module.types[ty].inner; + match *inner { + TypeInner::Struct { .. } => self.write_struct_name(module, ty)?, + ref other => self.write_value_type(module, other)?, + } + + Ok(()) + } + + /// Helper method used to write value types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult { + match *inner { + TypeInner::Vector { size, kind, width } => write!( + self.out, + "vec{}<{}>", + back::vector_size_str(size), + scalar_kind_str(kind, width), + )?, + TypeInner::Sampler { comparison: false } => { + write!(self.out, "sampler")?; + } + TypeInner::Sampler { comparison: true } => { + write!(self.out, "sampler_comparison")?; + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + // More about texture types: https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type + use crate::ImageClass as Ic; + + let dim_str = image_dimension_str(dim); + let arrayed_str = if arrayed { "_array" } else { "" }; + let (class_str, multisampled_str, format_str, storage_str) = match class { + Ic::Sampled { kind, multi } => ( + "", + if multi { "multisampled_" } else { "" }, + scalar_kind_str(kind, 4), + "", + ), + Ic::Depth { multi } => { + ("depth_", if multi { "multisampled_" } else { "" }, "", "") + } + Ic::Storage { format, access } => ( + "storage_", + "", + storage_format_str(format), + if access.contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE) + { + ",read_write" + } else if access.contains(crate::StorageAccess::LOAD) { + ",read" + } else { + ",write" + }, + ), + }; + write!( + self.out, + "texture_{class_str}{multisampled_str}{dim_str}{arrayed_str}" + )?; + + if !format_str.is_empty() { + write!(self.out, "<{format_str}{storage_str}>")?; + } + } + TypeInner::Scalar { kind, width } => { + write!(self.out, "{}", scalar_kind_str(kind, width))?; + } + TypeInner::Atomic { kind, width } => { + write!(self.out, "atomic<{}>", scalar_kind_str(kind, width))?; + } + TypeInner::Array { + base, + size, + stride: _, + } => { + // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types + // array<A, 3> -- Constant array + // array<A> -- Dynamic array + write!(self.out, "array<")?; + match size { + crate::ArraySize::Constant(handle) => { + self.write_type(module, base)?; + write!(self.out, ",")?; + self.write_constant(module, handle)?; + } + crate::ArraySize::Dynamic => { + self.write_type(module, base)?; + } + } + write!(self.out, ">")?; + } + TypeInner::BindingArray { base, size } => { + // More info https://github.com/gpuweb/gpuweb/issues/2105 + write!(self.out, "binding_array<")?; + match size { + crate::ArraySize::Constant(handle) => { + self.write_type(module, base)?; + write!(self.out, ",")?; + self.write_constant(module, handle)?; + } + crate::ArraySize::Dynamic => { + self.write_type(module, base)?; + } + } + write!(self.out, ">")?; + } + TypeInner::Matrix { + columns, + rows, + width: _, + } => { + write!( + self.out, + //TODO: Can matrix be other than f32? + "mat{}x{}<f32>", + back::vector_size_str(columns), + back::vector_size_str(rows), + )?; + } + TypeInner::Pointer { base, space } => { + let (address, maybe_access) = address_space_str(space); + // Everything but `AddressSpace::Handle` gives us a `address` name, but + // Naga IR never produces pointers to handles, so it doesn't matter much + // how we write such a type. Just write it as the base type alone. + if let Some(space) = address { + write!(self.out, "ptr<{space}, ")?; + } + self.write_type(module, base)?; + if address.is_some() { + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } + } + TypeInner::ValuePointer { + size: None, + kind, + width, + space, + } => { + let (address, maybe_access) = address_space_str(space); + if let Some(space) = address { + write!(self.out, "ptr<{}, {}", space, scalar_kind_str(kind, width))?; + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } else { + return Err(Error::Unimplemented(format!( + "ValuePointer to AddressSpace::Handle {inner:?}" + ))); + } + } + TypeInner::ValuePointer { + size: Some(size), + kind, + width, + space, + } => { + let (address, maybe_access) = address_space_str(space); + if let Some(space) = address { + write!( + self.out, + "ptr<{}, vec{}<{}>", + space, + back::vector_size_str(size), + scalar_kind_str(kind, width) + )?; + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } else { + return Err(Error::Unimplemented(format!( + "ValuePointer to AddressSpace::Handle {inner:?}" + ))); + } + write!(self.out, ">")?; + } + _ => { + return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))); + } + } + + Ok(()) + } + /// Helper method used to write statements + /// + /// # Notes + /// Always adds a newline + fn write_stmt( + &mut self, + module: &Module, + stmt: &crate::Statement, + func_ctx: &back::FunctionCtx<'_>, + level: back::Level, + ) -> BackendResult { + use crate::{Expression, Statement}; + + match *stmt { + Statement::Emit(ref range) => { + for handle in range.clone() { + let info = &func_ctx.info[handle]; + let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) { + // Front end provides names for all variables at the start of writing. + // But we write them to step by step. We need to recache them + // Otherwise, we could accidentally write variable name instead of full expression. + // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords. + Some(self.namer.call(name)) + } else if info.ref_count == 0 { + write!(self.out, "{level}_ = ")?; + self.write_expr(module, handle, func_ctx)?; + writeln!(self.out, ";")?; + continue; + } else { + let expr = &func_ctx.expressions[handle]; + let min_ref_count = expr.bake_ref_count(); + // Forcefully creating baking expressions in some cases to help with readability + let required_baking_expr = match *expr { + Expression::ImageLoad { .. } + | Expression::ImageQuery { .. } + | Expression::ImageSample { .. } => true, + _ => false, + }; + if min_ref_count <= info.ref_count || required_baking_expr { + Some(format!("{}{}", back::BAKE_PREFIX, handle.index())) + } else { + None + } + }; + + if let Some(name) = expr_name { + write!(self.out, "{level}")?; + self.start_named_expr(module, handle, func_ctx, &name)?; + self.write_expr(module, handle, func_ctx)?; + self.named_expressions.insert(handle, name); + writeln!(self.out, ";")?; + } + } + } + // TODO: copy-paste from glsl-out + Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{level}")?; + write!(self.out, "if ")?; + self.write_expr(module, condition, func_ctx)?; + writeln!(self.out, " {{")?; + + let l2 = level.next(); + for sta in accept { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + + // If there are no statements in the reject block we skip writing it + // This is only for readability + if !reject.is_empty() { + writeln!(self.out, "{level}}} else {{")?; + + for sta in reject { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + } + + writeln!(self.out, "{level}}}")? + } + Statement::Return { value } => { + write!(self.out, "{level}")?; + write!(self.out, "return")?; + if let Some(return_value) = value { + // The leading space is important + write!(self.out, " ")?; + self.write_expr(module, return_value, func_ctx)?; + } + writeln!(self.out, ";")?; + } + // TODO: copy-paste from glsl-out + Statement::Kill => { + write!(self.out, "{level}")?; + writeln!(self.out, "discard;")? + } + Statement::Store { pointer, value } => { + write!(self.out, "{level}")?; + + let is_atomic = match *func_ctx.info[pointer].ty.inner_with(&module.types) { + crate::TypeInner::Pointer { base, .. } => match module.types[base].inner { + crate::TypeInner::Atomic { .. } => true, + _ => false, + }, + _ => false, + }; + if is_atomic { + write!(self.out, "atomicStore(")?; + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr_with_indirection( + module, + pointer, + func_ctx, + Indirection::Reference, + )?; + write!(self.out, " = ")?; + self.write_expr(module, value, func_ctx)?; + } + writeln!(self.out, ";")? + } + Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{level}")?; + if let Some(expr) = result { + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + self.start_named_expr(module, expr, func_ctx, &name)?; + self.named_expressions.insert(expr, name); + } + let func_name = &self.names[&NameKey::Function(function)]; + write!(self.out, "{func_name}(")?; + for (index, &argument) in arguments.iter().enumerate() { + self.write_expr(module, argument, func_ctx)?; + // Only write a comma if isn't the last element + if index != arguments.len().saturating_sub(1) { + // The leading space is for readability only + write!(self.out, ", ")?; + } + } + writeln!(self.out, ");")? + } + Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + let fun_str = fun.to_wgsl(); + write!(self.out, "atomic{fun_str}(")?; + self.write_expr(module, pointer, func_ctx)?; + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + write!(self.out, ", ")?; + self.write_expr(module, cmp, func_ctx)?; + } + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ");")? + } + Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + write!(self.out, "{level}")?; + write!(self.out, "textureStore(")?; + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + if let Some(array_index_expr) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index_expr, func_ctx)?; + } + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ");")?; + } + // TODO: copy-paste from glsl-out + Statement::Block(ref block) => { + write!(self.out, "{level}")?; + writeln!(self.out, "{{")?; + for sta in block.iter() { + // Increase the indentation to help with readability + self.write_stmt(module, sta, func_ctx, level.next())? + } + writeln!(self.out, "{level}}}")? + } + Statement::Switch { + selector, + ref cases, + } => { + // Start the switch + write!(self.out, "{level}")?; + write!(self.out, "switch ")?; + self.write_expr(module, selector, func_ctx)?; + writeln!(self.out, " {{")?; + + let l2 = level.next(); + let mut new_case = true; + for case in cases { + if case.fall_through && !case.body.is_empty() { + // TODO: we could do the same workaround as we did for the HLSL backend + return Err(Error::Unimplemented( + "fall-through switch case block".into(), + )); + } + + match case.value { + crate::SwitchValue::I32(value) => { + if new_case { + write!(self.out, "{l2}case ")?; + } + write!(self.out, "{value}")?; + } + crate::SwitchValue::U32(value) => { + if new_case { + write!(self.out, "{l2}case ")?; + } + write!(self.out, "{value}u")?; + } + crate::SwitchValue::Default => { + if new_case { + if case.fall_through { + write!(self.out, "{l2}case ")?; + } else { + write!(self.out, "{l2}")?; + } + } + write!(self.out, "default")?; + } + } + + new_case = !case.fall_through; + + if case.fall_through { + write!(self.out, ", ")?; + } else { + writeln!(self.out, ": {{")?; + } + + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, l2.next())?; + } + + if !case.fall_through { + writeln!(self.out, "{l2}}}")?; + } + } + + writeln!(self.out, "{level}}}")? + } + Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + write!(self.out, "{level}")?; + writeln!(self.out, "loop {{")?; + + let l2 = level.next(); + for sta in body.iter() { + self.write_stmt(module, sta, func_ctx, l2)?; + } + + // The continuing is optional so we don't need to write it if + // it is empty, but the `break if` counts as a continuing statement + // so even if `continuing` is empty we must generate it if a + // `break if` exists + if !continuing.is_empty() || break_if.is_some() { + writeln!(self.out, "{l2}continuing {{")?; + for sta in continuing.iter() { + self.write_stmt(module, sta, func_ctx, l2.next())?; + } + + // The `break if` is always the last + // statement of the `continuing` block + if let Some(condition) = break_if { + // The trailing space is important + write!(self.out, "{}break if ", l2.next())?; + self.write_expr(module, condition, func_ctx)?; + // Close the `break if` statement + writeln!(self.out, ";")?; + } + + writeln!(self.out, "{l2}}}")?; + } + + writeln!(self.out, "{level}}}")? + } + Statement::Break => { + writeln!(self.out, "{level}break;")?; + } + Statement::Continue => { + writeln!(self.out, "{level}continue;")?; + } + Statement::Barrier(barrier) => { + if barrier.contains(crate::Barrier::STORAGE) { + writeln!(self.out, "{level}storageBarrier();")?; + } + + if barrier.contains(crate::Barrier::WORK_GROUP) { + writeln!(self.out, "{level}workgroupBarrier();")?; + } + } + Statement::RayQuery { .. } => unreachable!(), + } + + Ok(()) + } + + /// Return the sort of indirection that `expr`'s plain form evaluates to. + /// + /// An expression's 'plain form' is the most general rendition of that + /// expression into WGSL, lacking `&` or `*` operators: + /// + /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference + /// to the local variable's storage. + /// + /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a + /// reference to the global variable's storage. However, globals in the + /// `Handle` address space are immutable, and `GlobalVariable` expressions for + /// those produce the value directly, not a pointer to it. Such + /// `GlobalVariable` expressions are `Ordinary`. + /// + /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a + /// pointer. If they are applied directly to a composite value, they are + /// `Ordinary`. + /// + /// Note that `FunctionArgument` expressions are never `Reference`, even when + /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the + /// argument's value directly, so any pointer it produces is merely the value + /// passed by the caller. + fn plain_form_indirection( + &self, + expr: Handle<crate::Expression>, + module: &Module, + func_ctx: &back::FunctionCtx<'_>, + ) -> Indirection { + use crate::Expression as Ex; + + // Named expressions are `let` expressions, which apply the Load Rule, + // so if their type is a Naga pointer, then that must be a WGSL pointer + // as well. + if self.named_expressions.contains_key(&expr) { + return Indirection::Ordinary; + } + + match func_ctx.expressions[expr] { + Ex::LocalVariable(_) => Indirection::Reference, + Ex::GlobalVariable(handle) => { + let global = &module.global_variables[handle]; + match global.space { + crate::AddressSpace::Handle => Indirection::Ordinary, + _ => Indirection::Reference, + } + } + Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => { + let base_ty = func_ctx.info[base].ty.inner_with(&module.types); + match *base_ty { + crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => { + Indirection::Reference + } + _ => Indirection::Ordinary, + } + } + _ => Indirection::Ordinary, + } + } + + fn start_named_expr( + &mut self, + module: &Module, + handle: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx, + name: &str, + ) -> BackendResult { + // Write variable name + write!(self.out, "let {name}")?; + if self.flags.contains(WriterFlags::EXPLICIT_TYPES) { + write!(self.out, ": ")?; + let ty = &func_ctx.info[handle].ty; + // Write variable type + match *ty { + proc::TypeResolution::Handle(handle) => { + self.write_type(module, handle)?; + } + proc::TypeResolution::Value(ref inner) => { + self.write_value_type(module, inner)?; + } + } + } + + write!(self.out, " = ")?; + Ok(()) + } + + /// Write the ordinary WGSL form of `expr`. + /// + /// See `write_expr_with_indirection` for details. + fn write_expr( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + ) -> BackendResult { + self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary) + } + + /// Write `expr` as a WGSL expression with the requested indirection. + /// + /// In terms of the WGSL grammar, the resulting expression is a + /// `singular_expression`. It may be parenthesized. This makes it suitable + /// for use as the operand of a unary or binary operator without worrying + /// about precedence. + /// + /// This does not produce newlines or indentation. + /// + /// The `requested` argument indicates (roughly) whether Naga + /// `Pointer`-valued expressions represent WGSL references or pointers. See + /// `Indirection` for details. + fn write_expr_with_indirection( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + requested: Indirection, + ) -> BackendResult { + // If the plain form of the expression is not what we need, emit the + // operator necessary to correct that. + let plain = self.plain_form_indirection(expr, module, func_ctx); + match (requested, plain) { + (Indirection::Ordinary, Indirection::Reference) => { + write!(self.out, "(&")?; + self.write_expr_plain_form(module, expr, func_ctx, plain)?; + write!(self.out, ")")?; + } + (Indirection::Reference, Indirection::Ordinary) => { + write!(self.out, "(*")?; + self.write_expr_plain_form(module, expr, func_ctx, plain)?; + write!(self.out, ")")?; + } + (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?, + } + + Ok(()) + } + + /// Write the 'plain form' of `expr`. + /// + /// An expression's 'plain form' is the most general rendition of that + /// expression into WGSL, lacking `&` or `*` operators. The plain forms of + /// `LocalVariable(x)` and `GlobalVariable(g)` are simply `x` and `g`. Such + /// Naga expressions represent both WGSL pointers and references; it's the + /// caller's responsibility to distinguish those cases appropriately. + fn write_expr_plain_form( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + indirection: Indirection, + ) -> BackendResult { + use crate::Expression; + + if let Some(name) = self.named_expressions.get(&expr) { + write!(self.out, "{name}")?; + return Ok(()); + } + + let expression = &func_ctx.expressions[expr]; + + // Write the plain WGSL form of a Naga expression. + // + // The plain form of `LocalVariable` and `GlobalVariable` expressions is + // simply the variable name; `*` and `&` operators are never emitted. + // + // The plain form of `Access` and `AccessIndex` expressions are WGSL + // `postfix_expression` forms for member/component access and + // subscripting. + match *expression { + Expression::Constant(constant) => self.write_constant(module, constant)?, + Expression::Compose { ty, ref components } => { + self.write_type(module, ty)?; + write!(self.out, "(")?; + for (index, component) in components.iter().enumerate() { + self.write_expr(module, *component, func_ctx)?; + // Only write a comma if isn't the last element + if index != components.len().saturating_sub(1) { + // The leading space is for readability only + write!(self.out, ", ")?; + } + } + write!(self.out, ")")? + } + Expression::FunctionArgument(pos) => { + let name_key = func_ctx.argument_key(pos); + let name = &self.names[&name_key]; + write!(self.out, "{name}")?; + } + Expression::Binary { op, left, right } => { + write!(self.out, "(")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, " {} ", back::binary_operation_str(op))?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Access { base, index } => { + self.write_expr_with_indirection(module, base, func_ctx, indirection)?; + write!(self.out, "[")?; + self.write_expr(module, index, func_ctx)?; + write!(self.out, "]")? + } + Expression::AccessIndex { base, index } => { + let base_ty_res = &func_ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + + self.write_expr_with_indirection(module, base, func_ctx, indirection)?; + + let base_ty_handle = match *resolved { + TypeInner::Pointer { base, space: _ } => { + resolved = &module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + + match *resolved { + TypeInner::Vector { .. } => { + // Write vector access as a swizzle + write!(self.out, ".{}", back::COMPONENTS[index as usize])? + } + TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::BindingArray { .. } + | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?, + TypeInner::Struct { .. } => { + // This will never panic in case the type is a `Struct`, this is not true + // for other types so we can only check while inside this match arm + let ty = base_ty_handle.unwrap(); + + write!( + self.out, + ".{}", + &self.names[&NameKey::StructMember(ty, index)] + )? + } + ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))), + } + } + Expression::ImageSample { + image, + sampler, + gather: None, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + use crate::SampleLevel as Sl; + + let suffix_cmp = match depth_ref { + Some(_) => "Compare", + None => "", + }; + let suffix_level = match level { + Sl::Auto => "", + Sl::Zero | Sl::Exact(_) => "Level", + Sl::Bias(_) => "Bias", + Sl::Gradient { .. } => "Grad", + }; + + write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?; + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, sampler, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + + if let Some(array_index) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index, func_ctx)?; + } + + if let Some(depth_ref) = depth_ref { + write!(self.out, ", ")?; + self.write_expr(module, depth_ref, func_ctx)?; + } + + match level { + Sl::Auto => {} + Sl::Zero => { + // Level 0 is implied for depth comparison + if depth_ref.is_none() { + write!(self.out, ", 0.0")?; + } + } + Sl::Exact(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Bias(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Gradient { x, y } => { + write!(self.out, ", ")?; + self.write_expr(module, x, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, y, func_ctx)?; + } + } + + if let Some(offset) = offset { + write!(self.out, ", ")?; + self.write_constant(module, offset)?; + } + + write!(self.out, ")")?; + } + Expression::ImageSample { + image, + sampler, + gather: Some(component), + coordinate, + array_index, + offset, + level: _, + depth_ref, + } => { + let suffix_cmp = match depth_ref { + Some(_) => "Compare", + None => "", + }; + + write!(self.out, "textureGather{suffix_cmp}(")?; + match *func_ctx.info[image].ty.inner_with(&module.types) { + TypeInner::Image { + class: crate::ImageClass::Depth { multi: _ }, + .. + } => {} + _ => { + write!(self.out, "{}, ", component as u8)?; + } + } + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, sampler, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + + if let Some(array_index) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index, func_ctx)?; + } + + if let Some(depth_ref) = depth_ref { + write!(self.out, ", ")?; + self.write_expr(module, depth_ref, func_ctx)?; + } + + if let Some(offset) = offset { + write!(self.out, ", ")?; + self.write_constant(module, offset)?; + } + + write!(self.out, ")")?; + } + Expression::ImageQuery { image, query } => { + use crate::ImageQuery as Iq; + + let texture_function = match query { + Iq::Size { .. } => "textureDimensions", + Iq::NumLevels => "textureNumLevels", + Iq::NumLayers => "textureNumLayers", + Iq::NumSamples => "textureNumSamples", + }; + + write!(self.out, "{texture_function}(")?; + self.write_expr(module, image, func_ctx)?; + if let Iq::Size { level: Some(level) } = query { + write!(self.out, ", ")?; + self.write_expr(module, level, func_ctx)?; + }; + write!(self.out, ")")?; + } + Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + write!(self.out, "textureLoad(")?; + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + if let Some(array_index) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index, func_ctx)?; + } + if let Some(index) = sample.or(level) { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + write!(self.out, ")")?; + } + Expression::GlobalVariable(handle) => { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{name}")?; + } + Expression::As { + expr, + kind, + convert, + } => { + let inner = func_ctx.info[expr].ty.inner_with(&module.types); + match *inner { + TypeInner::Matrix { + columns, + rows, + width, + .. + } => { + let scalar_kind_str = scalar_kind_str(kind, convert.unwrap_or(width)); + write!( + self.out, + "mat{}x{}<{}>", + back::vector_size_str(columns), + back::vector_size_str(rows), + scalar_kind_str + )?; + } + TypeInner::Vector { size, width, .. } => { + let vector_size_str = back::vector_size_str(size); + let scalar_kind_str = scalar_kind_str(kind, convert.unwrap_or(width)); + if convert.is_some() { + write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?; + } else { + write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?; + } + } + TypeInner::Scalar { width, .. } => { + let scalar_kind_str = scalar_kind_str(kind, convert.unwrap_or(width)); + if convert.is_some() { + write!(self.out, "{scalar_kind_str}")? + } else { + write!(self.out, "bitcast<{scalar_kind_str}>")? + } + } + _ => { + return Err(Error::Unimplemented(format!( + "write_expr expression::as {inner:?}" + ))); + } + }; + write!(self.out, "(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Splat { size, value } => { + let inner = func_ctx.info[value].ty.inner_with(&module.types); + let (scalar_kind, scalar_width) = match *inner { + crate::TypeInner::Scalar { kind, width } => (kind, width), + _ => { + return Err(Error::Unimplemented(format!( + "write_expr expression::splat {inner:?}" + ))); + } + }; + let scalar = scalar_kind_str(scalar_kind, scalar_width); + let size = back::vector_size_str(size); + + write!(self.out, "vec{size}<{scalar}>(")?; + self.write_expr(module, value, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Load { pointer } => { + let is_atomic = match *func_ctx.info[pointer].ty.inner_with(&module.types) { + crate::TypeInner::Pointer { base, .. } => match module.types[base].inner { + crate::TypeInner::Atomic { .. } => true, + _ => false, + }, + _ => false, + }; + + if is_atomic { + write!(self.out, "atomicLoad(")?; + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr_with_indirection( + module, + pointer, + func_ctx, + Indirection::Reference, + )?; + } + } + Expression::LocalVariable(handle) => { + write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])? + } + Expression::ArrayLength(expr) => { + write!(self.out, "arrayLength(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + enum Function { + Regular(&'static str), + } + + let function = match fun { + Mf::Abs => Function::Regular("abs"), + Mf::Min => Function::Regular("min"), + Mf::Max => Function::Regular("max"), + Mf::Clamp => Function::Regular("clamp"), + Mf::Saturate => Function::Regular("saturate"), + // trigonometry + Mf::Cos => Function::Regular("cos"), + Mf::Cosh => Function::Regular("cosh"), + Mf::Sin => Function::Regular("sin"), + Mf::Sinh => Function::Regular("sinh"), + Mf::Tan => Function::Regular("tan"), + Mf::Tanh => Function::Regular("tanh"), + Mf::Acos => Function::Regular("acos"), + Mf::Asin => Function::Regular("asin"), + Mf::Atan => Function::Regular("atan"), + Mf::Atan2 => Function::Regular("atan2"), + Mf::Asinh => Function::Regular("asinh"), + Mf::Acosh => Function::Regular("acosh"), + Mf::Atanh => Function::Regular("atanh"), + Mf::Radians => Function::Regular("radians"), + Mf::Degrees => Function::Regular("degrees"), + // decomposition + Mf::Ceil => Function::Regular("ceil"), + Mf::Floor => Function::Regular("floor"), + Mf::Round => Function::Regular("round"), + Mf::Fract => Function::Regular("fract"), + Mf::Trunc => Function::Regular("trunc"), + Mf::Modf => Function::Regular("modf"), + Mf::Frexp => Function::Regular("frexp"), + Mf::Ldexp => Function::Regular("ldexp"), + // exponent + Mf::Exp => Function::Regular("exp"), + Mf::Exp2 => Function::Regular("exp2"), + Mf::Log => Function::Regular("log"), + Mf::Log2 => Function::Regular("log2"), + Mf::Pow => Function::Regular("pow"), + // geometry + Mf::Dot => Function::Regular("dot"), + Mf::Outer => Function::Regular("outerProduct"), + Mf::Cross => Function::Regular("cross"), + Mf::Distance => Function::Regular("distance"), + Mf::Length => Function::Regular("length"), + Mf::Normalize => Function::Regular("normalize"), + Mf::FaceForward => Function::Regular("faceForward"), + Mf::Reflect => Function::Regular("reflect"), + Mf::Refract => Function::Regular("refract"), + // computational + Mf::Sign => Function::Regular("sign"), + Mf::Fma => Function::Regular("fma"), + Mf::Mix => Function::Regular("mix"), + Mf::Step => Function::Regular("step"), + Mf::SmoothStep => Function::Regular("smoothstep"), + Mf::Sqrt => Function::Regular("sqrt"), + Mf::InverseSqrt => Function::Regular("inverseSqrt"), + Mf::Transpose => Function::Regular("transpose"), + Mf::Determinant => Function::Regular("determinant"), + // bits + Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"), + Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"), + Mf::CountOneBits => Function::Regular("countOneBits"), + Mf::ReverseBits => Function::Regular("reverseBits"), + Mf::ExtractBits => Function::Regular("extractBits"), + Mf::InsertBits => Function::Regular("insertBits"), + Mf::FindLsb => Function::Regular("firstTrailingBit"), + Mf::FindMsb => Function::Regular("firstLeadingBit"), + // data packing + Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"), + Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"), + Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"), + Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"), + Mf::Pack2x16float => Function::Regular("pack2x16float"), + // data unpacking + Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"), + Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"), + Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"), + Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"), + Mf::Unpack2x16float => Function::Regular("unpack2x16float"), + Mf::Inverse => { + return Err(Error::UnsupportedMathFunction(fun)); + } + }; + + match function { + Function::Regular(fun_name) => { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + write!(self.out, ")")? + } + } + } + Expression::Swizzle { + size, + vector, + pattern, + } => { + self.write_expr(module, vector, func_ctx)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + self.out.write_char(back::COMPONENTS[sc as usize])?; + } + } + Expression::Unary { op, expr } => { + let unary = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::Not => { + match *func_ctx.info[expr].ty.inner_with(&module.types) { + TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + .. + } + | TypeInner::Vector { .. } => "!", + _ => "~", + } + } + }; + + write!(self.out, "{unary}(")?; + self.write_expr(module, expr, func_ctx)?; + + write!(self.out, ")")? + } + Expression::Select { + condition, + accept, + reject, + } => { + write!(self.out, "select(")?; + self.write_expr(module, reject, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, accept, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, condition, func_ctx)?; + write!(self.out, ")")? + } + Expression::Derivative { axis, ctrl, expr } => { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + let op = match (axis, ctrl) { + (Axis::X, Ctrl::Coarse) => "dpdxCoarse", + (Axis::X, Ctrl::Fine) => "dpdxFine", + (Axis::X, Ctrl::None) => "dpdx", + (Axis::Y, Ctrl::Coarse) => "dpdyCoarse", + (Axis::Y, Ctrl::Fine) => "dpdyFine", + (Axis::Y, Ctrl::None) => "dpdy", + (Axis::Width, Ctrl::Coarse) => "fwidthCoarse", + (Axis::Width, Ctrl::Fine) => "fwidthFine", + (Axis::Width, Ctrl::None) => "fwidth", + }; + write!(self.out, "{op}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")? + } + Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + + let fun_name = match fun { + Rf::All => "all", + Rf::Any => "any", + _ => return Err(Error::UnsupportedRelationalFunction(fun)), + }; + write!(self.out, "{fun_name}(")?; + + self.write_expr(module, argument, func_ctx)?; + + write!(self.out, ")")? + } + // Not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), + // Nothing to do here, since call expression already cached + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult => {} + } + + Ok(()) + } + + /// Helper method used to write global variables + /// # Notes + /// Always adds a newline + fn write_global( + &mut self, + module: &Module, + global: &crate::GlobalVariable, + handle: Handle<crate::GlobalVariable>, + ) -> BackendResult { + // Write group and binding attributes if present + if let Some(ref binding) = global.binding { + self.write_attributes(&[ + Attribute::Group(binding.group), + Attribute::Binding(binding.binding), + ])?; + writeln!(self.out)?; + } + + // First write global name and address space if supported + write!(self.out, "var")?; + let (address, maybe_access) = address_space_str(global.space); + if let Some(space) = address { + write!(self.out, "<{space}")?; + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } + write!( + self.out, + " {}: ", + &self.names[&NameKey::GlobalVariable(handle)] + )?; + + // Write global type + self.write_type(module, global.ty)?; + + // Write initializer + if let Some(init) = global.init { + write!(self.out, " = ")?; + self.write_constant(module, init)?; + } + + // End with semicolon + writeln!(self.out, ";")?; + + Ok(()) + } + + /// Helper method used to write constants + /// + /// # Notes + /// Doesn't add any newlines or leading/trailing spaces + fn write_constant( + &mut self, + module: &Module, + handle: Handle<crate::Constant>, + ) -> BackendResult { + let constant = &module.constants[handle]; + match constant.inner { + crate::ConstantInner::Scalar { + width: _, + ref value, + } => { + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.write_scalar_value(*value)?; + } + } + crate::ConstantInner::Composite { ty, ref components } => { + self.write_type(module, ty)?; + write!(self.out, "(")?; + + // Write the comma separated constants + for (index, constant) in components.iter().enumerate() { + self.write_constant(module, *constant)?; + // Only write a comma if isn't the last element + if index != components.len().saturating_sub(1) { + // The leading space is for readability only + write!(self.out, ", ")?; + } + } + write!(self.out, ")")? + } + } + + Ok(()) + } + + /// Helper method used to write global constants + /// + /// # Notes + /// Ends in a newline + fn write_global_constant( + &mut self, + module: &Module, + inner: &crate::ConstantInner, + handle: Handle<crate::Constant>, + ) -> BackendResult { + match *inner { + crate::ConstantInner::Scalar { + width: _, + ref value, + } => { + let name = &self.names[&NameKey::Constant(handle)]; + // First write only constant name + write!(self.out, "const {name}: ")?; + // Next write constant type and value + match *value { + crate::ScalarValue::Sint(value) => { + write!(self.out, "i32 = {value}")?; + } + crate::ScalarValue::Uint(value) => { + write!(self.out, "u32 = {value}u")?; + } + crate::ScalarValue::Float(value) => { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + write!(self.out, "f32 = {value:?}")?; + } + crate::ScalarValue::Bool(value) => { + write!(self.out, "bool = {value}")?; + } + }; + // End with semicolon + writeln!(self.out, ";")?; + } + crate::ConstantInner::Composite { ty, ref components } => { + let name = &self.names[&NameKey::Constant(handle)]; + // First write only constant name + write!(self.out, "const {name}: ")?; + // Next write constant type + self.write_type(module, ty)?; + + write!(self.out, " = ")?; + self.write_type(module, ty)?; + + write!(self.out, "(")?; + for (index, constant) in components.iter().enumerate() { + self.write_constant(module, *constant)?; + // Only write a comma if isn't the last element + if index != components.len().saturating_sub(1) { + // The leading space is for readability only + write!(self.out, ", ")?; + } + } + write!(self.out, ");")?; + } + } + // End with extra newline for readability + writeln!(self.out)?; + Ok(()) + } + + // See https://github.com/rust-lang/rust-clippy/issues/4979. + #[allow(clippy::missing_const_for_fn)] + pub fn finish(self) -> W { + self.out + } +} + +fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { + use crate::BuiltIn as Bi; + + Ok(match built_in { + Bi::VertexIndex => "vertex_index", + Bi::InstanceIndex => "instance_index", + Bi::Position { .. } => "position", + Bi::FrontFacing => "front_facing", + Bi::FragDepth => "frag_depth", + Bi::LocalInvocationId => "local_invocation_id", + Bi::LocalInvocationIndex => "local_invocation_index", + Bi::GlobalInvocationId => "global_invocation_id", + Bi::WorkGroupId => "workgroup_id", + Bi::NumWorkGroups => "num_workgroups", + Bi::SampleIndex => "sample_index", + Bi::SampleMask => "sample_mask", + Bi::PrimitiveIndex => "primitive_index", + Bi::ViewIndex => "view_index", + Bi::BaseInstance + | Bi::BaseVertex + | Bi::ClipDistance + | Bi::CullDistance + | Bi::PointSize + | Bi::PointCoord + | Bi::WorkGroupSize => { + return Err(Error::Custom(format!("Unsupported builtin {built_in:?}"))) + } + }) +} + +const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str { + use crate::ImageDimension as IDim; + + match dim { + IDim::D1 => "1d", + IDim::D2 => "2d", + IDim::D3 => "3d", + IDim::Cube => "cube", + } +} + +const fn scalar_kind_str(kind: crate::ScalarKind, width: u8) -> &'static str { + use crate::ScalarKind as Sk; + + match (kind, width) { + (Sk::Float, 8) => "f64", + (Sk::Float, 4) => "f32", + (Sk::Sint, 4) => "i32", + (Sk::Uint, 4) => "u32", + (Sk::Bool, 1) => "bool", + _ => unreachable!(), + } +} + +const fn storage_format_str(format: crate::StorageFormat) -> &'static str { + use crate::StorageFormat as Sf; + + match format { + Sf::R8Unorm => "r8unorm", + Sf::R8Snorm => "r8snorm", + Sf::R8Uint => "r8uint", + Sf::R8Sint => "r8sint", + Sf::R16Uint => "r16uint", + Sf::R16Sint => "r16sint", + Sf::R16Float => "r16float", + Sf::Rg8Unorm => "rg8unorm", + Sf::Rg8Snorm => "rg8snorm", + Sf::Rg8Uint => "rg8uint", + Sf::Rg8Sint => "rg8sint", + Sf::R32Uint => "r32uint", + Sf::R32Sint => "r32sint", + Sf::R32Float => "r32float", + Sf::Rg16Uint => "rg16uint", + Sf::Rg16Sint => "rg16sint", + Sf::Rg16Float => "rg16float", + Sf::Rgba8Unorm => "rgba8unorm", + Sf::Rgba8Snorm => "rgba8snorm", + Sf::Rgba8Uint => "rgba8uint", + Sf::Rgba8Sint => "rgba8sint", + Sf::Rgb10a2Unorm => "rgb10a2unorm", + Sf::Rg11b10Float => "rg11b10float", + Sf::Rg32Uint => "rg32uint", + Sf::Rg32Sint => "rg32sint", + Sf::Rg32Float => "rg32float", + Sf::Rgba16Uint => "rgba16uint", + Sf::Rgba16Sint => "rgba16sint", + Sf::Rgba16Float => "rgba16float", + Sf::Rgba32Uint => "rgba32uint", + Sf::Rgba32Sint => "rgba32sint", + Sf::Rgba32Float => "rgba32float", + Sf::R16Unorm => "r16unorm", + Sf::R16Snorm => "r16snorm", + Sf::Rg16Unorm => "rg16unorm", + Sf::Rg16Snorm => "rg16snorm", + Sf::Rgba16Unorm => "rgba16unorm", + Sf::Rgba16Snorm => "rgba16snorm", + } +} + +/// Helper function that returns the string corresponding to the WGSL interpolation qualifier +const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str { + use crate::Interpolation as I; + + match interpolation { + I::Perspective => "perspective", + I::Linear => "linear", + I::Flat => "flat", + } +} + +/// Return the WGSL auxiliary qualifier for the given sampling value. +const fn sampling_str(sampling: crate::Sampling) -> &'static str { + use crate::Sampling as S; + + match sampling { + S::Center => "", + S::Centroid => "centroid", + S::Sample => "sample", + } +} + +const fn address_space_str( + space: crate::AddressSpace, +) -> (Option<&'static str>, Option<&'static str>) { + use crate::AddressSpace as As; + + ( + Some(match space { + As::Private => "private", + As::Uniform => "uniform", + As::Storage { access } => { + if access.contains(crate::StorageAccess::STORE) { + return (Some("storage"), Some("read_write")); + } else { + "storage" + } + } + As::PushConstant => "push_constant", + As::WorkGroup => "workgroup", + As::Handle => return (None, None), + As::Function => "function", + }), + None, + ) +} + +fn map_binding_to_attribute( + binding: &crate::Binding, + scalar_kind: Option<crate::ScalarKind>, +) -> Vec<Attribute> { + match *binding { + crate::Binding::BuiltIn(built_in) => { + if let crate::BuiltIn::Position { invariant: true } = built_in { + vec![Attribute::BuiltIn(built_in), Attribute::Invariant] + } else { + vec![Attribute::BuiltIn(built_in)] + } + } + crate::Binding::Location { + location, + interpolation, + sampling, + } => match scalar_kind { + Some(crate::ScalarKind::Float) => vec![ + Attribute::Location(location), + Attribute::Interpolate(interpolation, sampling), + ], + _ => vec![Attribute::Location(location)], + }, + } +} |