summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/back/hlsl/writer.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/back/hlsl/writer.rs')
-rw-r--r--third_party/rust/naga/src/back/hlsl/writer.rs183
1 files changed, 77 insertions, 106 deletions
diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs
index 43f7212837..4ba856946b 100644
--- a/third_party/rust/naga/src/back/hlsl/writer.rs
+++ b/third_party/rust/naga/src/back/hlsl/writer.rs
@@ -19,6 +19,8 @@ const SPECIAL_OTHER: &str = "other";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
+pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
+pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
struct EpStructMember {
name: String,
@@ -125,14 +127,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.need_bake_expressions.insert(fun_handle);
}
- if let Expression::Math {
- fun,
- arg,
- arg1,
- arg2,
- arg3,
- } = *expr
- {
+ if let Expression::Math { fun, arg, .. } = *expr {
match fun {
crate::MathFunction::Asinh
| crate::MathFunction::Acosh
@@ -149,17 +144,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
| crate::MathFunction::Pack4x8unorm => {
self.need_bake_expressions.insert(arg);
}
- crate::MathFunction::ExtractBits => {
- self.need_bake_expressions.insert(arg);
- self.need_bake_expressions.insert(arg1.unwrap());
- self.need_bake_expressions.insert(arg2.unwrap());
- }
- crate::MathFunction::InsertBits => {
- self.need_bake_expressions.insert(arg);
- self.need_bake_expressions.insert(arg1.unwrap());
- self.need_bake_expressions.insert(arg2.unwrap());
- self.need_bake_expressions.insert(arg3.unwrap());
- }
crate::MathFunction::CountLeadingZeros => {
let inner = info[fun_handle].ty.inner_with(&module.types);
if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
@@ -2038,6 +2022,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}", value)?,
+ crate::Literal::U64(value) => write!(self.out, "{}uL", value)?,
crate::Literal::I64(value) => write!(self.out, "{}L", value)?,
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
@@ -2567,7 +2552,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
convert,
} => {
let inner = func_ctx.resolve_type(expr, &module.types);
- match convert {
+ let close_paren = match convert {
Some(dst_width) => {
let scalar = crate::Scalar {
kind,
@@ -2600,13 +2585,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
)));
}
};
+ true
}
None => {
- write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
+ if inner.scalar_width() == Some(64) {
+ false
+ } else {
+ write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
+ true
+ }
}
- }
+ };
self.write_expr(module, expr, func_ctx)?;
- write!(self.out, ")")?;
+ if close_paren {
+ write!(self.out, ")")?;
+ }
}
Expression::Math {
fun,
@@ -2620,8 +2613,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
enum Function {
Asincosh { is_sin: bool },
Atanh,
- ExtractBits,
- InsertBits,
Pack2x16float,
Pack2x16snorm,
Pack2x16unorm,
@@ -2705,8 +2696,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"),
Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"),
- Mf::ExtractBits => Function::ExtractBits,
- Mf::InsertBits => Function::InsertBits,
+ Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
+ Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
// Data Packing
Mf::Pack2x16float => Function::Pack2x16float,
Mf::Pack2x16snorm => Function::Pack2x16snorm,
@@ -2742,70 +2733,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
}
- Function::ExtractBits => {
- // e: T,
- // offset: u32,
- // count: u32
- // T is u32 or i32 or vecN<u32> or vecN<i32>
- if let (Some(offset), Some(count)) = (arg1, arg2) {
- let scalar_width: u8 = 32;
- // Works for signed and unsigned
- // (count == 0 ? 0 : (e << (32 - count - offset)) >> (32 - count))
- write!(self.out, "(")?;
- self.write_expr(module, count, func_ctx)?;
- write!(self.out, " == 0 ? 0 : (")?;
- self.write_expr(module, arg, func_ctx)?;
- write!(self.out, " << ({scalar_width} - ")?;
- self.write_expr(module, count, func_ctx)?;
- write!(self.out, " - ")?;
- self.write_expr(module, offset, func_ctx)?;
- write!(self.out, ")) >> ({scalar_width} - ")?;
- self.write_expr(module, count, func_ctx)?;
- write!(self.out, "))")?;
- }
- }
- Function::InsertBits => {
- // e: T,
- // newbits: T,
- // offset: u32,
- // count: u32
- // returns T
- // T is i32, u32, vecN<i32>, or vecN<u32>
- if let (Some(newbits), Some(offset), Some(count)) = (arg1, arg2, arg3) {
- let scalar_width: u8 = 32;
- let scalar_max: u32 = 0xFFFFFFFF;
- // mask = ((0xFFFFFFFFu >> (32 - count)) << offset)
- // (count == 0 ? e : ((e & ~mask) | ((newbits << offset) & mask)))
- write!(self.out, "(")?;
- self.write_expr(module, count, func_ctx)?;
- write!(self.out, " == 0 ? ")?;
- self.write_expr(module, arg, func_ctx)?;
- write!(self.out, " : ")?;
- write!(self.out, "(")?;
- self.write_expr(module, arg, func_ctx)?;
- write!(self.out, " & ~")?;
- // mask
- write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
- self.write_expr(module, count, func_ctx)?;
- write!(self.out, ")) << ")?;
- self.write_expr(module, offset, func_ctx)?;
- write!(self.out, ")")?;
- // end mask
- write!(self.out, ") | ((")?;
- self.write_expr(module, newbits, func_ctx)?;
- write!(self.out, " << ")?;
- self.write_expr(module, offset, func_ctx)?;
- write!(self.out, ") & ")?;
- // // mask
- write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
- self.write_expr(module, count, func_ctx)?;
- write!(self.out, ")) << ")?;
- self.write_expr(module, offset, func_ctx)?;
- write!(self.out, ")")?;
- // // end mask
- write!(self.out, "))")?;
- }
- }
Function::Pack2x16float => {
write!(self.out, "(f32tof16(")?;
self.write_expr(module, arg, func_ctx)?;
@@ -2944,9 +2871,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
write!(self.out, ")")?
}
+ // These overloads are only missing on FXC, so this is only needed for 32bit types,
+ // as non-32bit types are DXC only.
Function::MissingIntOverload(fun_name) => {
- let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
- if let Some(ScalarKind::Sint) = scalar_kind {
+ let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
+ if let Some(crate::Scalar {
+ kind: ScalarKind::Sint,
+ width: 4,
+ }) = scalar_kind
+ {
write!(self.out, "asint({fun_name}(asuint(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
@@ -2956,9 +2889,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")")?;
}
}
+ // These overloads are only missing on FXC, so this is only needed for 32bit types,
+ // as non-32bit types are DXC only.
Function::MissingIntReturnType(fun_name) => {
- let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
- if let Some(ScalarKind::Sint) = scalar_kind {
+ let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
+ if let Some(crate::Scalar {
+ kind: ScalarKind::Sint,
+ width: 4,
+ }) = scalar_kind
+ {
write!(self.out, "asint({fun_name}(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
@@ -2977,23 +2916,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::VectorSize::Quad => ".xxxx",
};
- if let ScalarKind::Uint = scalar.kind {
- write!(self.out, "min((32u){s}, firstbitlow(")?;
+ let scalar_width_bits = scalar.width * 8;
+
+ if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
+ write!(
+ self.out,
+ "min(({scalar_width_bits}u){s}, firstbitlow("
+ )?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
- write!(self.out, "asint(min((32u){s}, firstbitlow(")?;
+ // This is only needed for the FXC path, on 32bit signed integers.
+ write!(
+ self.out,
+ "asint(min(({scalar_width_bits}u){s}, firstbitlow("
+ )?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}
}
TypeInner::Scalar(scalar) => {
- if let ScalarKind::Uint = scalar.kind {
- write!(self.out, "min(32u, firstbitlow(")?;
+ let scalar_width_bits = scalar.width * 8;
+
+ if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
+ write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
- write!(self.out, "asint(min(32u, firstbitlow(")?;
+ // This is only needed for the FXC path, on 32bit signed integers.
+ write!(
+ self.out,
+ "asint(min({scalar_width_bits}u, firstbitlow("
+ )?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}
@@ -3012,30 +2966,47 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::VectorSize::Quad => ".xxxx",
};
- if let ScalarKind::Uint = scalar.kind {
- write!(self.out, "((31u){s} - firstbithigh(")?;
+ // scalar width - 1
+ let constant = scalar.width * 8 - 1;
+
+ if scalar.kind == ScalarKind::Uint {
+ write!(self.out, "(({constant}u){s} - firstbithigh(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
+ let conversion_func = match scalar.width {
+ 4 => "asint",
+ _ => "",
+ };
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
- " < (0){s} ? (0){s} : (31){s} - asint(firstbithigh("
+ " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}
}
TypeInner::Scalar(scalar) => {
+ // scalar width - 1
+ let constant = scalar.width * 8 - 1;
+
if let ScalarKind::Uint = scalar.kind {
- write!(self.out, "(31u - firstbithigh(")?;
+ write!(self.out, "({constant}u - firstbithigh(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
+ let conversion_func = match scalar.width {
+ 4 => "asint",
+ _ => "",
+ };
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
- write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?;
+ write!(
+ self.out,
+ " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
+ )?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}