summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/back/msl/writer.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/back/msl/writer.rs')
-rw-r--r--third_party/rust/naga/src/back/msl/writer.rs200
1 files changed, 153 insertions, 47 deletions
diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs
index 1e496b5f50..5227d8e7db 100644
--- a/third_party/rust/naga/src/back/msl/writer.rs
+++ b/third_party/rust/naga/src/back/msl/writer.rs
@@ -319,7 +319,7 @@ pub struct Writer<W> {
}
impl crate::Scalar {
- const fn to_msl_name(self) -> &'static str {
+ fn to_msl_name(self) -> &'static str {
use crate::ScalarKind as Sk;
match self {
Self {
@@ -328,20 +328,29 @@ impl crate::Scalar {
} => "float",
Self {
kind: Sk::Sint,
- width: _,
+ width: 4,
} => "int",
Self {
kind: Sk::Uint,
- width: _,
+ width: 4,
} => "uint",
Self {
+ kind: Sk::Sint,
+ width: 8,
+ } => "long",
+ Self {
+ kind: Sk::Uint,
+ width: 8,
+ } => "ulong",
+ Self {
kind: Sk::Bool,
width: _,
} => "bool",
Self {
kind: Sk::AbstractInt | Sk::AbstractFloat,
width: _,
- } => unreachable!(),
+ } => unreachable!("Found Abstract scalar kind"),
+ _ => unreachable!("Unsupported scalar kind: {:?}", self),
}
}
}
@@ -735,7 +744,11 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Vector { size, .. } => {
put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
}
- _ => return Err(Error::Validation),
+ _ => {
+ return Err(Error::GenericValidation(
+ "Invalid type for image coordinate".into(),
+ ))
+ }
};
write!(self.out, "(")?;
@@ -1068,13 +1081,17 @@ impl<W: Write> Writer<W> {
let (offset, array_ty) = match context.module.types[global.ty].inner {
crate::TypeInner::Struct { ref members, .. } => match members.last() {
Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
- None => return Err(Error::Validation),
+ None => return Err(Error::GenericValidation("Struct has no members".into())),
},
crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} => (0, global.ty),
- _ => return Err(Error::Validation),
+ ref ty => {
+ return Err(Error::GenericValidation(format!(
+ "Expected type with dynamic array, got {ty:?}"
+ )))
+ }
};
let (size, stride) = match context.module.types[array_ty].inner {
@@ -1084,7 +1101,11 @@ impl<W: Write> Writer<W> {
.size(context.module.to_ctx()),
stride,
),
- _ => return Err(Error::Validation),
+ ref ty => {
+ return Err(Error::GenericValidation(format!(
+ "Expected array type, got {ty:?}"
+ )))
+ }
};
// When the stride length is larger than the size, the final element's stride of
@@ -1273,6 +1294,9 @@ impl<W: Write> Writer<W> {
crate::Literal::I32(value) => {
write!(self.out, "{value}")?;
}
+ crate::Literal::U64(value) => {
+ write!(self.out, "{value}uL")?;
+ }
crate::Literal::I64(value) => {
write!(self.out, "{value}L")?;
}
@@ -1280,7 +1304,9 @@ impl<W: Write> Writer<W> {
write!(self.out, "{value}")?;
}
crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
- return Err(Error::Validation);
+ return Err(Error::GenericValidation(
+ "Unsupported abstract literal".into(),
+ ));
}
},
crate::Expression::Constant(handle) => {
@@ -1342,7 +1368,11 @@ impl<W: Write> Writer<W> {
crate::Expression::Splat { size, value } => {
let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
crate::TypeInner::Scalar(scalar) => scalar,
- _ => return Err(Error::Validation),
+ ref ty => {
+ return Err(Error::GenericValidation(format!(
+ "Expected splat value type must be a scalar, got {ty:?}",
+ )))
+ }
};
put_numeric_type(&mut self.out, scalar, &[size])?;
write!(self.out, "(")?;
@@ -1672,7 +1702,11 @@ impl<W: Write> Writer<W> {
self.put_expression(condition, context, true)?;
write!(self.out, ")")?;
}
- _ => return Err(Error::Validation),
+ ref ty => {
+ return Err(Error::GenericValidation(format!(
+ "Expected select condition to be a non-bool type, got {ty:?}",
+ )))
+ }
},
crate::Expression::Derivative { axis, expr, .. } => {
use crate::DerivativeAxis as Axis;
@@ -1794,8 +1828,8 @@ impl<W: Write> Writer<W> {
Mf::CountLeadingZeros => "clz",
Mf::CountOneBits => "popcount",
Mf::ReverseBits => "reverse_bits",
- Mf::ExtractBits => "extract_bits",
- Mf::InsertBits => "insert_bits",
+ Mf::ExtractBits => "",
+ Mf::InsertBits => "",
Mf::FindLsb => "",
Mf::FindMsb => "",
// data packing
@@ -1836,15 +1870,23 @@ impl<W: Write> Writer<W> {
self.put_expression(arg1.unwrap(), context, false)?;
write!(self.out, ")")?;
} else if fun == Mf::FindLsb {
+ let scalar = context.resolve_type(arg).scalar().unwrap();
+ let constant = scalar.width * 8 + 1;
+
write!(self.out, "((({NAMESPACE}::ctz(")?;
self.put_expression(arg, context, true)?;
- write!(self.out, ") + 1) % 33) - 1)")?;
+ write!(self.out, ") + 1) % {constant}) - 1)")?;
} else if fun == Mf::FindMsb {
let inner = context.resolve_type(arg);
+ let scalar = inner.scalar().unwrap();
+ let constant = scalar.width * 8 - 1;
- write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?;
+ write!(
+ self.out,
+ "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
+ )?;
- if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ if scalar.kind == crate::ScalarKind::Sint {
write!(self.out, "{NAMESPACE}::select(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ~")?;
@@ -1862,18 +1904,12 @@ impl<W: Write> Writer<W> {
match *inner {
crate::TypeInner::Vector { size, scalar } => {
let size = back::vector_size_str(size);
- if let crate::ScalarKind::Sint = scalar.kind {
- write!(self.out, "int{size}")?;
- } else {
- write!(self.out, "uint{size}")?;
- }
+ let name = scalar.to_msl_name();
+ write!(self.out, "{name}{size}")?;
}
crate::TypeInner::Scalar(scalar) => {
- if let crate::ScalarKind::Sint = scalar.kind {
- write!(self.out, "int")?;
- } else {
- write!(self.out, "uint")?;
- }
+ let name = scalar.to_msl_name();
+ write!(self.out, "{name}")?;
}
_ => (),
}
@@ -1891,6 +1927,52 @@ impl<W: Write> Writer<W> {
write!(self.out, "as_type<uint>(half2(")?;
self.put_expression(arg, context, false)?;
write!(self.out, "))")?;
+ } else if fun == Mf::ExtractBits {
+ // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
+ // to first sanitize the offset and count first. If we don't do this, Apple chips
+ // will return out-of-spec values if the extracted range is not within the bit width.
+ //
+ // This encodes the exact formula specified by the wgsl spec, without temporary values:
+ // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
+ //
+ // w = sizeof(x) * 8
+ // o = min(offset, w)
+ // tmp = w - o
+ // c = min(count, tmp)
+ //
+ // bitfieldExtract(x, o, c)
+ //
+ // extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
+
+ let scalar_bits = context.resolve_type(arg).scalar_width().unwrap();
+
+ write!(self.out, "{NAMESPACE}::extract_bits(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", {NAMESPACE}::min(")?;
+ self.put_expression(arg1.unwrap(), context, true)?;
+ write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
+ self.put_expression(arg2.unwrap(), context, true)?;
+ write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
+ self.put_expression(arg1.unwrap(), context, true)?;
+ write!(self.out, ", {scalar_bits}u)))")?;
+ } else if fun == Mf::InsertBits {
+ // The behavior of InsertBits has the same issue as ExtractBits.
+ //
+ // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
+
+ let scalar_bits = context.resolve_type(arg).scalar_width().unwrap();
+
+ write!(self.out, "{NAMESPACE}::insert_bits(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(arg1.unwrap(), context, true)?;
+ write!(self.out, ", {NAMESPACE}::min(")?;
+ self.put_expression(arg2.unwrap(), context, true)?;
+ write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
+ self.put_expression(arg3.unwrap(), context, true)?;
+ write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
+ self.put_expression(arg2.unwrap(), context, true)?;
+ write!(self.out, ", {scalar_bits}u)))")?;
} else if fun == Mf::Radians {
write!(self.out, "((")?;
self.put_expression(arg, context, false)?;
@@ -1920,14 +2002,8 @@ impl<W: Write> Writer<W> {
kind,
width: convert.unwrap_or(src.width),
};
- let is_bool_cast =
- kind == crate::ScalarKind::Bool || src.kind == crate::ScalarKind::Bool;
let op = match convert {
- Some(w) if w == src.width || is_bool_cast => "static_cast",
- Some(8) if kind == crate::ScalarKind::Float => {
- return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
- }
- Some(_) => return Err(Error::Validation),
+ Some(_) => "static_cast",
None => "as_type",
};
write!(self.out, "{op}<")?;
@@ -1955,7 +2031,11 @@ impl<W: Write> Writer<W> {
self.put_expression(expr, context, true)?;
write!(self.out, ")")?;
}
- _ => return Err(Error::Validation),
+ ref ty => {
+ return Err(Error::GenericValidation(format!(
+ "Unsupported type for As: {ty:?}"
+ )))
+ }
},
// has to be a named expression
crate::Expression::CallResult(_)
@@ -1970,11 +2050,19 @@ impl<W: Write> Writer<W> {
crate::Expression::AccessIndex { base, .. } => {
match context.function.expressions[base] {
crate::Expression::GlobalVariable(handle) => handle,
- _ => return Err(Error::Validation),
+ ref ex => {
+ return Err(Error::GenericValidation(format!(
+ "Expected global variable in AccessIndex, got {ex:?}"
+ )))
+ }
}
}
crate::Expression::GlobalVariable(handle) => handle,
- _ => return Err(Error::Validation),
+ ref ex => {
+ return Err(Error::GenericValidation(format!(
+ "Unexpected expression in ArrayLength, got {ex:?}"
+ )))
+ }
};
if !is_scoped {
@@ -2140,10 +2228,12 @@ impl<W: Write> Writer<W> {
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Dynamic => {
- let global = context
- .function
- .originating_global(base)
- .ok_or(Error::Validation)?;
+ let global =
+ context.function.originating_global(base).ok_or_else(|| {
+ Error::GenericValidation(
+ "Could not find originating global".into(),
+ )
+ })?;
write!(self.out, "1 + ")?;
self.put_dynamic_array_max_index(global, context)?
}
@@ -2300,10 +2390,9 @@ impl<W: Write> Writer<W> {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Dynamic => {
- let global = context
- .function
- .originating_global(base)
- .ok_or(Error::Validation)?;
+ let global = context.function.originating_global(base).ok_or_else(|| {
+ Error::GenericValidation("Could not find originating global".into())
+ })?;
self.put_dynamic_array_max_index(global, context)?;
}
}
@@ -2489,7 +2578,14 @@ impl<W: Write> Writer<W> {
}
}
- if let Expression::Math { fun, arg, arg1, .. } = *expr {
+ if let Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ ..
+ } = *expr
+ {
match fun {
crate::MathFunction::Dot => {
// WGSL's `dot` function works on any `vecN` type, but Metal's only
@@ -2514,6 +2610,14 @@ impl<W: Write> Writer<W> {
crate::MathFunction::FindMsb => {
self.need_bake_expressions.insert(arg);
}
+ crate::MathFunction::ExtractBits => {
+ // Only argument 1 is re-used.
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ crate::MathFunction::InsertBits => {
+ // Only argument 2 is re-used.
+ self.need_bake_expressions.insert(arg2.unwrap());
+ }
crate::MathFunction::Sign => {
// WGSL's `sign` function works also on signed ints, but Metal's only
// works on floating points, so we emit inline code for integer `sign`
@@ -3048,7 +3152,7 @@ impl<W: Write> Writer<W> {
for statement in statements {
if let crate::Statement::Emit(ref range) = *statement {
for handle in range.clone() {
- self.named_expressions.remove(&handle);
+ self.named_expressions.shift_remove(&handle);
}
}
}
@@ -3897,7 +4001,9 @@ impl<W: Write> Writer<W> {
binding: None,
first_time: true,
};
- let binding = binding.ok_or(Error::Validation)?;
+ let binding = binding.ok_or_else(|| {
+ Error::GenericValidation("Expected binding, got None".into())
+ })?;
if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
has_point_size = true;