diff options
Diffstat (limited to 'third_party/rust/naga/src/back/hlsl/help.rs')
-rw-r--r-- | third_party/rust/naga/src/back/hlsl/help.rs | 150 |
1 files changed, 149 insertions, 1 deletions
diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs index fa6062a1ad..4dd9ea5987 100644 --- a/third_party/rust/naga/src/back/hlsl/help.rs +++ b/third_party/rust/naga/src/back/hlsl/help.rs @@ -26,7 +26,11 @@ int dim_1d = NagaDimensions1D(image_1d); ``` */ -use super::{super::FunctionCtx, BackendResult}; +use super::{ + super::FunctionCtx, + writer::{EXTRACT_BITS_FUNCTION, INSERT_BITS_FUNCTION}, + BackendResult, +}; use crate::{arena::Handle, proc::NameKey}; use std::fmt::Write; @@ -59,6 +63,13 @@ pub(super) struct WrappedMatCx2 { pub(super) columns: crate::VectorSize, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedMath { + pub(super) fun: crate::MathFunction, + pub(super) scalar: crate::Scalar, + pub(super) components: Option<u32>, +} + /// HLSL backend requires its own `ImageQuery` enum. /// /// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function. @@ -851,12 +862,149 @@ impl<'a, W: Write> super::Writer<'a, W> { Ok(()) } + pub(super) fn write_wrapped_math_functions( + &mut self, + module: &crate::Module, + func_ctx: &FunctionCtx, + ) -> BackendResult { + for (_, expression) in func_ctx.expressions.iter() { + if let crate::Expression::Math { + fun, + arg, + arg1: _arg1, + arg2: _arg2, + arg3: _arg3, + } = *expression + { + match fun { + crate::MathFunction::ExtractBits => { + // The behavior of our extractBits polyfill is undefined if offset + count > bit_width. We need + // to first sanitize the offset and count first. If we don't do this, we will get out-of-spec + // values if the extracted range is not within the bit width. + // + // This encodes the exact formula specified by the wgsl spec: + // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin + // + // w = sizeof(x) * 8 + // o = min(offset, w) + // c = min(count, w - o) + // + // bitfieldExtract(x, o, c) + let arg_ty = func_ctx.resolve_type(arg, &module.types); + let scalar = arg_ty.scalar().unwrap(); + let components = arg_ty.components(); + + let wrapped = WrappedMath { + fun, + scalar, + components, + }; + + if !self.wrapped.math.insert(wrapped) { + continue; + } + + // Write return type + self.write_value_type(module, arg_ty)?; + + let scalar_width: u8 = scalar.width * 8; + + // Write function name and parameters + writeln!(self.out, " {EXTRACT_BITS_FUNCTION}(")?; + write!(self.out, " ")?; + self.write_value_type(module, arg_ty)?; + writeln!(self.out, " e,")?; + writeln!(self.out, " uint offset,")?; + writeln!(self.out, " uint count")?; + writeln!(self.out, ") {{")?; + + // Write function body + writeln!(self.out, " uint w = {scalar_width};")?; + writeln!(self.out, " uint o = min(offset, w);")?; + writeln!(self.out, " uint c = min(count, w - o);")?; + writeln!( + self.out, + " return (c == 0 ? 0 : (e << (w - c - o)) >> (w - c));" + )?; + + // End of function body + writeln!(self.out, "}}")?; + } + crate::MathFunction::InsertBits => { + // The behavior of our insertBits polyfill has the same constraints as the extractBits polyfill. + + let arg_ty = func_ctx.resolve_type(arg, &module.types); + let scalar = arg_ty.scalar().unwrap(); + let components = arg_ty.components(); + + let wrapped = WrappedMath { + fun, + scalar, + components, + }; + + if !self.wrapped.math.insert(wrapped) { + continue; + } + + // Write return type + self.write_value_type(module, arg_ty)?; + + let scalar_width: u8 = scalar.width * 8; + let scalar_max: u64 = match scalar.width { + 1 => 0xFF, + 2 => 0xFFFF, + 4 => 0xFFFFFFFF, + 8 => 0xFFFFFFFFFFFFFFFF, + _ => unreachable!(), + }; + + // Write function name and parameters + writeln!(self.out, " {INSERT_BITS_FUNCTION}(")?; + write!(self.out, " ")?; + self.write_value_type(module, arg_ty)?; + writeln!(self.out, " e,")?; + write!(self.out, " ")?; + self.write_value_type(module, arg_ty)?; + writeln!(self.out, " newbits,")?; + writeln!(self.out, " uint offset,")?; + writeln!(self.out, " uint count")?; + writeln!(self.out, ") {{")?; + + // Write function body + writeln!(self.out, " uint w = {scalar_width}u;")?; + writeln!(self.out, " uint o = min(offset, w);")?; + writeln!(self.out, " uint c = min(count, w - o);")?; + + // The `u` suffix on the literals is _extremely_ important. Otherwise it will use + // i32 shifting instead of the intended u32 shifting. + writeln!( + self.out, + " uint mask = (({scalar_max}u >> ({scalar_width}u - c)) << o);" + )?; + writeln!( + self.out, + " return (c == 0 ? e : ((e & ~mask) | ((newbits << o) & mask)));" + )?; + + // End of function body + writeln!(self.out, "}}")?; + } + _ => {} + } + } + } + + Ok(()) + } + /// Helper function that writes various wrapped functions pub(super) fn write_wrapped_functions( &mut self, module: &crate::Module, func_ctx: &FunctionCtx, ) -> BackendResult { + self.write_wrapped_math_functions(module, func_ctx)?; self.write_wrapped_compose_functions(module, func_ctx.expressions)?; for (handle, _) in func_ctx.expressions.iter() { |