diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 01:13:27 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 01:13:27 +0000 |
commit | 40a355a42d4a9444dc753c04c6608dade2f06a23 (patch) | |
tree | 871fc667d2de662f171103ce5ec067014ef85e61 /third_party/rust/naga/src/back/hlsl | |
parent | Adding upstream version 124.0.1. (diff) | |
download | firefox-upstream/125.0.1.tar.xz firefox-upstream/125.0.1.zip |
Adding upstream version 125.0.1.upstream/125.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/back/hlsl')
-rw-r--r-- | third_party/rust/naga/src/back/hlsl/conv.rs | 12 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/hlsl/help.rs | 150 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/hlsl/keywords.rs | 2 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/hlsl/mod.rs | 2 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/hlsl/storage.rs | 84 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/hlsl/writer.rs | 183 |
6 files changed, 306 insertions, 127 deletions
diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs index b6918ddc42..2a6db35db8 100644 --- a/third_party/rust/naga/src/back/hlsl/conv.rs +++ b/third_party/rust/naga/src/back/hlsl/conv.rs @@ -21,8 +21,16 @@ impl crate::Scalar { /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar> pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> { match self.kind { - crate::ScalarKind::Sint => Ok("int"), - crate::ScalarKind::Uint => Ok("uint"), + crate::ScalarKind::Sint => match self.width { + 4 => Ok("int"), + 8 => Ok("int64_t"), + _ => Err(Error::UnsupportedScalar(self)), + }, + crate::ScalarKind::Uint => match self.width { + 4 => Ok("uint"), + 8 => Ok("uint64_t"), + _ => Err(Error::UnsupportedScalar(self)), + }, crate::ScalarKind::Float => match self.width { 2 => Ok("half"), 4 => Ok("float"), diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs index fa6062a1ad..4dd9ea5987 100644 --- a/third_party/rust/naga/src/back/hlsl/help.rs +++ b/third_party/rust/naga/src/back/hlsl/help.rs @@ -26,7 +26,11 @@ int dim_1d = NagaDimensions1D(image_1d); ``` */ -use super::{super::FunctionCtx, BackendResult}; +use super::{ + super::FunctionCtx, + writer::{EXTRACT_BITS_FUNCTION, INSERT_BITS_FUNCTION}, + BackendResult, +}; use crate::{arena::Handle, proc::NameKey}; use std::fmt::Write; @@ -59,6 +63,13 @@ pub(super) struct WrappedMatCx2 { pub(super) columns: crate::VectorSize, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedMath { + pub(super) fun: crate::MathFunction, + pub(super) scalar: crate::Scalar, + pub(super) components: Option<u32>, +} + /// HLSL backend requires its own `ImageQuery` enum. /// /// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function. @@ -851,12 +862,149 @@ impl<'a, W: Write> super::Writer<'a, W> { Ok(()) } + pub(super) fn write_wrapped_math_functions( + &mut self, + module: &crate::Module, + func_ctx: &FunctionCtx, + ) -> BackendResult { + for (_, expression) in func_ctx.expressions.iter() { + if let crate::Expression::Math { + fun, + arg, + arg1: _arg1, + arg2: _arg2, + arg3: _arg3, + } = *expression + { + match fun { + crate::MathFunction::ExtractBits => { + // The behavior of our extractBits polyfill is undefined if offset + count > bit_width. We need + // to first sanitize the offset and count first. If we don't do this, we will get out-of-spec + // values if the extracted range is not within the bit width. + // + // This encodes the exact formula specified by the wgsl spec: + // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin + // + // w = sizeof(x) * 8 + // o = min(offset, w) + // c = min(count, w - o) + // + // bitfieldExtract(x, o, c) + let arg_ty = func_ctx.resolve_type(arg, &module.types); + let scalar = arg_ty.scalar().unwrap(); + let components = arg_ty.components(); + + let wrapped = WrappedMath { + fun, + scalar, + components, + }; + + if !self.wrapped.math.insert(wrapped) { + continue; + } + + // Write return type + self.write_value_type(module, arg_ty)?; + + let scalar_width: u8 = scalar.width * 8; + + // Write function name and parameters + writeln!(self.out, " {EXTRACT_BITS_FUNCTION}(")?; + write!(self.out, " ")?; + self.write_value_type(module, arg_ty)?; + writeln!(self.out, " e,")?; + writeln!(self.out, " uint offset,")?; + writeln!(self.out, " uint count")?; + writeln!(self.out, ") {{")?; + + // Write function body + writeln!(self.out, " uint w = {scalar_width};")?; + writeln!(self.out, " uint o = min(offset, w);")?; + writeln!(self.out, " uint c = min(count, w - o);")?; + writeln!( + self.out, + " return (c == 0 ? 0 : (e << (w - c - o)) >> (w - c));" + )?; + + // End of function body + writeln!(self.out, "}}")?; + } + crate::MathFunction::InsertBits => { + // The behavior of our insertBits polyfill has the same constraints as the extractBits polyfill. + + let arg_ty = func_ctx.resolve_type(arg, &module.types); + let scalar = arg_ty.scalar().unwrap(); + let components = arg_ty.components(); + + let wrapped = WrappedMath { + fun, + scalar, + components, + }; + + if !self.wrapped.math.insert(wrapped) { + continue; + } + + // Write return type + self.write_value_type(module, arg_ty)?; + + let scalar_width: u8 = scalar.width * 8; + let scalar_max: u64 = match scalar.width { + 1 => 0xFF, + 2 => 0xFFFF, + 4 => 0xFFFFFFFF, + 8 => 0xFFFFFFFFFFFFFFFF, + _ => unreachable!(), + }; + + // Write function name and parameters + writeln!(self.out, " {INSERT_BITS_FUNCTION}(")?; + write!(self.out, " ")?; + self.write_value_type(module, arg_ty)?; + writeln!(self.out, " e,")?; + write!(self.out, " ")?; + self.write_value_type(module, arg_ty)?; + writeln!(self.out, " newbits,")?; + writeln!(self.out, " uint offset,")?; + writeln!(self.out, " uint count")?; + writeln!(self.out, ") {{")?; + + // Write function body + writeln!(self.out, " uint w = {scalar_width}u;")?; + writeln!(self.out, " uint o = min(offset, w);")?; + writeln!(self.out, " uint c = min(count, w - o);")?; + + // The `u` suffix on the literals is _extremely_ important. Otherwise it will use + // i32 shifting instead of the intended u32 shifting. + writeln!( + self.out, + " uint mask = (({scalar_max}u >> ({scalar_width}u - c)) << o);" + )?; + writeln!( + self.out, + " return (c == 0 ? e : ((e & ~mask) | ((newbits << o) & mask)));" + )?; + + // End of function body + writeln!(self.out, "}}")?; + } + _ => {} + } + } + } + + Ok(()) + } + /// Helper function that writes various wrapped functions pub(super) fn write_wrapped_functions( &mut self, module: &crate::Module, func_ctx: &FunctionCtx, ) -> BackendResult { + self.write_wrapped_math_functions(module, func_ctx)?; self.write_wrapped_compose_functions(module, func_ctx.expressions)?; for (handle, _) in func_ctx.expressions.iter() { diff --git a/third_party/rust/naga/src/back/hlsl/keywords.rs b/third_party/rust/naga/src/back/hlsl/keywords.rs index 059e533ff7..2cb715c42c 100644 --- a/third_party/rust/naga/src/back/hlsl/keywords.rs +++ b/third_party/rust/naga/src/back/hlsl/keywords.rs @@ -817,6 +817,8 @@ pub const RESERVED: &[&str] = &[ // Naga utilities super::writer::MODF_FUNCTION, super::writer::FREXP_FUNCTION, + super::writer::EXTRACT_BITS_FUNCTION, + super::writer::INSERT_BITS_FUNCTION, ]; // DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254 diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs index 37ddbd3d67..f37a223f47 100644 --- a/third_party/rust/naga/src/back/hlsl/mod.rs +++ b/third_party/rust/naga/src/back/hlsl/mod.rs @@ -256,6 +256,7 @@ struct Wrapped { constructors: crate::FastHashSet<help::WrappedConstructor>, struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>, mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>, + math: crate::FastHashSet<help::WrappedMath>, } impl Wrapped { @@ -265,6 +266,7 @@ impl Wrapped { self.constructors.clear(); self.struct_matrix_access.clear(); self.mat_cx2s.clear(); + self.math.clear(); } } diff --git a/third_party/rust/naga/src/back/hlsl/storage.rs b/third_party/rust/naga/src/back/hlsl/storage.rs index 1b8a6ec12d..4d3a6af56d 100644 --- a/third_party/rust/naga/src/back/hlsl/storage.rs +++ b/third_party/rust/naga/src/back/hlsl/storage.rs @@ -32,6 +32,16 @@ The [`temp_access_chain`] field is a member of [`Writer`] solely to allow re-use of the `Vec`'s dynamic allocation. Its value is no longer needed once HLSL for the access has been generated. +Note about DXC and Load/Store functions: + +DXC's HLSL has a generic [`Load` and `Store`] function for [`ByteAddressBuffer`] and +[`RWByteAddressBuffer`]. This is not available in FXC's HLSL, so we use +it only for types that are only available in DXC. Notably 64 and 16 bit types. + +FXC's HLSL has functions Load, Load2, Load3, and Load4 and Store, Store2, Store3, Store4. +This loads/stores a vector of length 1, 2, 3, or 4. We use that for 32bit types, bitcasting to the +correct type if necessary. + [`Storage`]: crate::AddressSpace::Storage [`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer [`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer @@ -42,6 +52,7 @@ needed once HLSL for the access has been generated. [`Writer::temp_access_chain`]: super::Writer::temp_access_chain [`temp_access_chain`]: super::Writer::temp_access_chain [`Writer`]: super::Writer +[`Load` and `Store`]: https://github.com/microsoft/DirectXShaderCompiler/wiki/ByteAddressBuffer-Load-Store-Additions */ use super::{super::FunctionCtx, BackendResult, Error}; @@ -161,20 +172,39 @@ impl<W: fmt::Write> super::Writer<'_, W> { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; - let cast = scalar.kind.to_hlsl_cast(); - write!(self.out, "{cast}({var_name}.Load(")?; + // See note about DXC and Load/Store in the module's documentation. + if scalar.width == 4 { + let cast = scalar.kind.to_hlsl_cast(); + write!(self.out, "{cast}({var_name}.Load(")?; + } else { + let ty = scalar.to_hlsl_str()?; + write!(self.out, "{var_name}.Load<{ty}>(")?; + }; self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, "))")?; + write!(self.out, ")")?; + if scalar.width == 4 { + write!(self.out, ")")?; + } self.temp_access_chain = chain; } crate::TypeInner::Vector { size, scalar } => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; - let cast = scalar.kind.to_hlsl_cast(); - write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?; + let size = size as u8; + // See note about DXC and Load/Store in the module's documentation. + if scalar.width == 4 { + let cast = scalar.kind.to_hlsl_cast(); + write!(self.out, "{cast}({var_name}.Load{size}(")?; + } else { + let ty = scalar.to_hlsl_str()?; + write!(self.out, "{var_name}.Load<{ty}{size}>(")?; + }; self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, "))")?; + write!(self.out, ")")?; + if scalar.width == 4 { + write!(self.out, ")")?; + } self.temp_access_chain = chain; } crate::TypeInner::Matrix { @@ -288,26 +318,44 @@ impl<W: fmt::Write> super::Writer<'_, W> { } }; match *ty_resolution.inner_with(&module.types) { - crate::TypeInner::Scalar(_) => { + crate::TypeInner::Scalar(scalar) => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; - write!(self.out, "{level}{var_name}.Store(")?; - self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, ", asuint(")?; - self.write_store_value(module, &value, func_ctx)?; - writeln!(self.out, "));")?; + // See note about DXC and Load/Store in the module's documentation. + if scalar.width == 4 { + write!(self.out, "{level}{var_name}.Store(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", asuint(")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, "));")?; + } else { + write!(self.out, "{level}{var_name}.Store(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", ")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ");")?; + } self.temp_access_chain = chain; } - crate::TypeInner::Vector { size, .. } => { + crate::TypeInner::Vector { size, scalar } => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; - write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?; - self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, ", asuint(")?; - self.write_store_value(module, &value, func_ctx)?; - writeln!(self.out, "));")?; + // See note about DXC and Load/Store in the module's documentation. + if scalar.width == 4 { + write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", asuint(")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, "));")?; + } else { + write!(self.out, "{}{}.Store(", level, var_name)?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", ")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ");")?; + } self.temp_access_chain = chain; } crate::TypeInner::Matrix { diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs index 43f7212837..4ba856946b 100644 --- a/third_party/rust/naga/src/back/hlsl/writer.rs +++ b/third_party/rust/naga/src/back/hlsl/writer.rs @@ -19,6 +19,8 @@ const SPECIAL_OTHER: &str = "other"; pub(crate) const MODF_FUNCTION: &str = "naga_modf"; pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; +pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits"; +pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits"; struct EpStructMember { name: String, @@ -125,14 +127,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.need_bake_expressions.insert(fun_handle); } - if let Expression::Math { - fun, - arg, - arg1, - arg2, - arg3, - } = *expr - { + if let Expression::Math { fun, arg, .. } = *expr { match fun { crate::MathFunction::Asinh | crate::MathFunction::Acosh @@ -149,17 +144,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { | crate::MathFunction::Pack4x8unorm => { self.need_bake_expressions.insert(arg); } - crate::MathFunction::ExtractBits => { - self.need_bake_expressions.insert(arg); - self.need_bake_expressions.insert(arg1.unwrap()); - self.need_bake_expressions.insert(arg2.unwrap()); - } - crate::MathFunction::InsertBits => { - self.need_bake_expressions.insert(arg); - self.need_bake_expressions.insert(arg1.unwrap()); - self.need_bake_expressions.insert(arg2.unwrap()); - self.need_bake_expressions.insert(arg3.unwrap()); - } crate::MathFunction::CountLeadingZeros => { let inner = info[fun_handle].ty.inner_with(&module.types); if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { @@ -2038,6 +2022,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::Literal::F32(value) => write!(self.out, "{value:?}")?, crate::Literal::U32(value) => write!(self.out, "{}u", value)?, crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::U64(value) => write!(self.out, "{}uL", value)?, crate::Literal::I64(value) => write!(self.out, "{}L", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { @@ -2567,7 +2552,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { convert, } => { let inner = func_ctx.resolve_type(expr, &module.types); - match convert { + let close_paren = match convert { Some(dst_width) => { let scalar = crate::Scalar { kind, @@ -2600,13 +2585,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { ))); } }; + true } None => { - write!(self.out, "{}(", kind.to_hlsl_cast(),)?; + if inner.scalar_width() == Some(64) { + false + } else { + write!(self.out, "{}(", kind.to_hlsl_cast(),)?; + true + } } - } + }; self.write_expr(module, expr, func_ctx)?; - write!(self.out, ")")?; + if close_paren { + write!(self.out, ")")?; + } } Expression::Math { fun, @@ -2620,8 +2613,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { enum Function { Asincosh { is_sin: bool }, Atanh, - ExtractBits, - InsertBits, Pack2x16float, Pack2x16snorm, Pack2x16unorm, @@ -2705,8 +2696,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Mf::ReverseBits => Function::MissingIntOverload("reversebits"), Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"), Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"), - Mf::ExtractBits => Function::ExtractBits, - Mf::InsertBits => Function::InsertBits, + Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION), + Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION), // Data Packing Mf::Pack2x16float => Function::Pack2x16float, Mf::Pack2x16snorm => Function::Pack2x16snorm, @@ -2742,70 +2733,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } - Function::ExtractBits => { - // e: T, - // offset: u32, - // count: u32 - // T is u32 or i32 or vecN<u32> or vecN<i32> - if let (Some(offset), Some(count)) = (arg1, arg2) { - let scalar_width: u8 = 32; - // Works for signed and unsigned - // (count == 0 ? 0 : (e << (32 - count - offset)) >> (32 - count)) - write!(self.out, "(")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, " == 0 ? 0 : (")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, " << ({scalar_width} - ")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, " - ")?; - self.write_expr(module, offset, func_ctx)?; - write!(self.out, ")) >> ({scalar_width} - ")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, "))")?; - } - } - Function::InsertBits => { - // e: T, - // newbits: T, - // offset: u32, - // count: u32 - // returns T - // T is i32, u32, vecN<i32>, or vecN<u32> - if let (Some(newbits), Some(offset), Some(count)) = (arg1, arg2, arg3) { - let scalar_width: u8 = 32; - let scalar_max: u32 = 0xFFFFFFFF; - // mask = ((0xFFFFFFFFu >> (32 - count)) << offset) - // (count == 0 ? e : ((e & ~mask) | ((newbits << offset) & mask))) - write!(self.out, "(")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, " == 0 ? ")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, " : ")?; - write!(self.out, "(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, " & ~")?; - // mask - write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, ")) << ")?; - self.write_expr(module, offset, func_ctx)?; - write!(self.out, ")")?; - // end mask - write!(self.out, ") | ((")?; - self.write_expr(module, newbits, func_ctx)?; - write!(self.out, " << ")?; - self.write_expr(module, offset, func_ctx)?; - write!(self.out, ") & ")?; - // // mask - write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, ")) << ")?; - self.write_expr(module, offset, func_ctx)?; - write!(self.out, ")")?; - // // end mask - write!(self.out, "))")?; - } - } Function::Pack2x16float => { write!(self.out, "(f32tof16(")?; self.write_expr(module, arg, func_ctx)?; @@ -2944,9 +2871,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } write!(self.out, ")")? } + // These overloads are only missing on FXC, so this is only needed for 32bit types, + // as non-32bit types are DXC only. Function::MissingIntOverload(fun_name) => { - let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); - if let Some(ScalarKind::Sint) = scalar_kind { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar(); + if let Some(crate::Scalar { + kind: ScalarKind::Sint, + width: 4, + }) = scalar_kind + { write!(self.out, "asint({fun_name}(asuint(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; @@ -2956,9 +2889,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; } } + // These overloads are only missing on FXC, so this is only needed for 32bit types, + // as non-32bit types are DXC only. Function::MissingIntReturnType(fun_name) => { - let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); - if let Some(ScalarKind::Sint) = scalar_kind { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar(); + if let Some(crate::Scalar { + kind: ScalarKind::Sint, + width: 4, + }) = scalar_kind + { write!(self.out, "asint({fun_name}(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; @@ -2977,23 +2916,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::VectorSize::Quad => ".xxxx", }; - if let ScalarKind::Uint = scalar.kind { - write!(self.out, "min((32u){s}, firstbitlow(")?; + let scalar_width_bits = scalar.width * 8; + + if scalar.kind == ScalarKind::Uint || scalar.width != 4 { + write!( + self.out, + "min(({scalar_width_bits}u){s}, firstbitlow(" + )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { - write!(self.out, "asint(min((32u){s}, firstbitlow(")?; + // This is only needed for the FXC path, on 32bit signed integers. + write!( + self.out, + "asint(min(({scalar_width_bits}u){s}, firstbitlow(" + )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } } TypeInner::Scalar(scalar) => { - if let ScalarKind::Uint = scalar.kind { - write!(self.out, "min(32u, firstbitlow(")?; + let scalar_width_bits = scalar.width * 8; + + if scalar.kind == ScalarKind::Uint || scalar.width != 4 { + write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { - write!(self.out, "asint(min(32u, firstbitlow(")?; + // This is only needed for the FXC path, on 32bit signed integers. + write!( + self.out, + "asint(min({scalar_width_bits}u, firstbitlow(" + )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } @@ -3012,30 +2966,47 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::VectorSize::Quad => ".xxxx", }; - if let ScalarKind::Uint = scalar.kind { - write!(self.out, "((31u){s} - firstbithigh(")?; + // scalar width - 1 + let constant = scalar.width * 8 - 1; + + if scalar.kind == ScalarKind::Uint { + write!(self.out, "(({constant}u){s} - firstbithigh(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { + let conversion_func = match scalar.width { + 4 => "asint", + _ => "", + }; write!(self.out, "(")?; self.write_expr(module, arg, func_ctx)?; write!( self.out, - " < (0){s} ? (0){s} : (31){s} - asint(firstbithigh(" + " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } } TypeInner::Scalar(scalar) => { + // scalar width - 1 + let constant = scalar.width * 8 - 1; + if let ScalarKind::Uint = scalar.kind { - write!(self.out, "(31u - firstbithigh(")?; + write!(self.out, "({constant}u - firstbithigh(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { + let conversion_func = match scalar.width { + 4 => "asint", + _ => "", + }; write!(self.out, "(")?; self.write_expr(module, arg, func_ctx)?; - write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?; + write!( + self.out, + " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh(" + )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } |