diff options
Diffstat (limited to 'third_party/rust/naga/src/back/msl')
-rw-r--r-- | third_party/rust/naga/src/back/msl/keywords.rs | 2 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/msl/mod.rs | 15 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/msl/writer.rs | 200 |
3 files changed, 161 insertions, 56 deletions
diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs index f0025bf239..73c457dd34 100644 --- a/third_party/rust/naga/src/back/msl/keywords.rs +++ b/third_party/rust/naga/src/back/msl/keywords.rs @@ -4,6 +4,8 @@ // C++ - Standard for Programming Language C++ (N4431) // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf pub const RESERVED: &[&str] = &[ + // Undocumented + "assert", // found in https://github.com/gfx-rs/wgpu/issues/5347 // Standard for Programming Language C++ (N4431): 2.5 Alternative tokens "and", "bitor", diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs index 5ef18730c9..68e5b79906 100644 --- a/third_party/rust/naga/src/back/msl/mod.rs +++ b/third_party/rust/naga/src/back/msl/mod.rs @@ -121,8 +121,8 @@ pub enum Error { UnsupportedCall(String), #[error("feature '{0}' is not implemented yet")] FeatureNotImplemented(String), - #[error("module is not valid")] - Validation, + #[error("internal naga error: module should not have validated: {0}")] + GenericValidation(String), #[error("BuiltIn {0:?} is not supported")] UnsupportedBuiltIn(crate::BuiltIn), #[error("capability {0:?} is not supported")] @@ -306,13 +306,10 @@ impl Options { }, }) } - LocationMode::Uniform => { - log::error!( - "Unexpected Binding::Location({}) for the Uniform mode", - location - ); - Err(Error::Validation) - } + LocationMode::Uniform => Err(Error::GenericValidation(format!( + "Unexpected Binding::Location({}) for the Uniform mode", + location + ))), }, } } diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs index 1e496b5f50..5227d8e7db 100644 --- a/third_party/rust/naga/src/back/msl/writer.rs +++ b/third_party/rust/naga/src/back/msl/writer.rs @@ -319,7 +319,7 @@ pub struct Writer<W> { } impl crate::Scalar { - const fn to_msl_name(self) -> &'static str { + fn to_msl_name(self) -> &'static str { use crate::ScalarKind as Sk; match self { Self { @@ -328,20 +328,29 @@ impl crate::Scalar { } => "float", Self { kind: Sk::Sint, - width: _, + width: 4, } => "int", Self { kind: Sk::Uint, - width: _, + width: 4, } => "uint", Self { + kind: Sk::Sint, + width: 8, + } => "long", + Self { + kind: Sk::Uint, + width: 8, + } => "ulong", + Self { kind: Sk::Bool, width: _, } => "bool", Self { kind: Sk::AbstractInt | Sk::AbstractFloat, width: _, - } => unreachable!(), + } => unreachable!("Found Abstract scalar kind"), + _ => unreachable!("Unsupported scalar kind: {:?}", self), } } } @@ -735,7 +744,11 @@ impl<W: Write> Writer<W> { crate::TypeInner::Vector { size, .. } => { put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])? } - _ => return Err(Error::Validation), + _ => { + return Err(Error::GenericValidation( + "Invalid type for image coordinate".into(), + )) + } }; write!(self.out, "(")?; @@ -1068,13 +1081,17 @@ impl<W: Write> Writer<W> { let (offset, array_ty) = match context.module.types[global.ty].inner { crate::TypeInner::Struct { ref members, .. } => match members.last() { Some(&crate::StructMember { offset, ty, .. }) => (offset, ty), - None => return Err(Error::Validation), + None => return Err(Error::GenericValidation("Struct has no members".into())), }, crate::TypeInner::Array { size: crate::ArraySize::Dynamic, .. } => (0, global.ty), - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Expected type with dynamic array, got {ty:?}" + ))) + } }; let (size, stride) = match context.module.types[array_ty].inner { @@ -1084,7 +1101,11 @@ impl<W: Write> Writer<W> { .size(context.module.to_ctx()), stride, ), - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Expected array type, got {ty:?}" + ))) + } }; // When the stride length is larger than the size, the final element's stride of @@ -1273,6 +1294,9 @@ impl<W: Write> Writer<W> { crate::Literal::I32(value) => { write!(self.out, "{value}")?; } + crate::Literal::U64(value) => { + write!(self.out, "{value}uL")?; + } crate::Literal::I64(value) => { write!(self.out, "{value}L")?; } @@ -1280,7 +1304,9 @@ impl<W: Write> Writer<W> { write!(self.out, "{value}")?; } crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { - return Err(Error::Validation); + return Err(Error::GenericValidation( + "Unsupported abstract literal".into(), + )); } }, crate::Expression::Constant(handle) => { @@ -1342,7 +1368,11 @@ impl<W: Write> Writer<W> { crate::Expression::Splat { size, value } => { let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) { crate::TypeInner::Scalar(scalar) => scalar, - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Expected splat value type must be a scalar, got {ty:?}", + ))) + } }; put_numeric_type(&mut self.out, scalar, &[size])?; write!(self.out, "(")?; @@ -1672,7 +1702,11 @@ impl<W: Write> Writer<W> { self.put_expression(condition, context, true)?; write!(self.out, ")")?; } - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Expected select condition to be a non-bool type, got {ty:?}", + ))) + } }, crate::Expression::Derivative { axis, expr, .. } => { use crate::DerivativeAxis as Axis; @@ -1794,8 +1828,8 @@ impl<W: Write> Writer<W> { Mf::CountLeadingZeros => "clz", Mf::CountOneBits => "popcount", Mf::ReverseBits => "reverse_bits", - Mf::ExtractBits => "extract_bits", - Mf::InsertBits => "insert_bits", + Mf::ExtractBits => "", + Mf::InsertBits => "", Mf::FindLsb => "", Mf::FindMsb => "", // data packing @@ -1836,15 +1870,23 @@ impl<W: Write> Writer<W> { self.put_expression(arg1.unwrap(), context, false)?; write!(self.out, ")")?; } else if fun == Mf::FindLsb { + let scalar = context.resolve_type(arg).scalar().unwrap(); + let constant = scalar.width * 8 + 1; + write!(self.out, "((({NAMESPACE}::ctz(")?; self.put_expression(arg, context, true)?; - write!(self.out, ") + 1) % 33) - 1)")?; + write!(self.out, ") + 1) % {constant}) - 1)")?; } else if fun == Mf::FindMsb { let inner = context.resolve_type(arg); + let scalar = inner.scalar().unwrap(); + let constant = scalar.width * 8 - 1; - write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?; + write!( + self.out, + "{NAMESPACE}::select({constant} - {NAMESPACE}::clz(" + )?; - if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { + if scalar.kind == crate::ScalarKind::Sint { write!(self.out, "{NAMESPACE}::select(")?; self.put_expression(arg, context, true)?; write!(self.out, ", ~")?; @@ -1862,18 +1904,12 @@ impl<W: Write> Writer<W> { match *inner { crate::TypeInner::Vector { size, scalar } => { let size = back::vector_size_str(size); - if let crate::ScalarKind::Sint = scalar.kind { - write!(self.out, "int{size}")?; - } else { - write!(self.out, "uint{size}")?; - } + let name = scalar.to_msl_name(); + write!(self.out, "{name}{size}")?; } crate::TypeInner::Scalar(scalar) => { - if let crate::ScalarKind::Sint = scalar.kind { - write!(self.out, "int")?; - } else { - write!(self.out, "uint")?; - } + let name = scalar.to_msl_name(); + write!(self.out, "{name}")?; } _ => (), } @@ -1891,6 +1927,52 @@ impl<W: Write> Writer<W> { write!(self.out, "as_type<uint>(half2(")?; self.put_expression(arg, context, false)?; write!(self.out, "))")?; + } else if fun == Mf::ExtractBits { + // The behavior of ExtractBits is undefined when offset + count > bit_width. We need + // to first sanitize the offset and count first. If we don't do this, Apple chips + // will return out-of-spec values if the extracted range is not within the bit width. + // + // This encodes the exact formula specified by the wgsl spec, without temporary values: + // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin + // + // w = sizeof(x) * 8 + // o = min(offset, w) + // tmp = w - o + // c = min(count, tmp) + // + // bitfieldExtract(x, o, c) + // + // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))) + + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap(); + + write!(self.out, "{NAMESPACE}::extract_bits(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", {NAMESPACE}::min(")?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; + self.put_expression(arg2.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u)))")?; + } else if fun == Mf::InsertBits { + // The behavior of InsertBits has the same issue as ExtractBits. + // + // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))) + + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap(); + + write!(self.out, "{NAMESPACE}::insert_bits(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ", {NAMESPACE}::min(")?; + self.put_expression(arg2.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; + self.put_expression(arg3.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; + self.put_expression(arg2.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u)))")?; } else if fun == Mf::Radians { write!(self.out, "((")?; self.put_expression(arg, context, false)?; @@ -1920,14 +2002,8 @@ impl<W: Write> Writer<W> { kind, width: convert.unwrap_or(src.width), }; - let is_bool_cast = - kind == crate::ScalarKind::Bool || src.kind == crate::ScalarKind::Bool; let op = match convert { - Some(w) if w == src.width || is_bool_cast => "static_cast", - Some(8) if kind == crate::ScalarKind::Float => { - return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) - } - Some(_) => return Err(Error::Validation), + Some(_) => "static_cast", None => "as_type", }; write!(self.out, "{op}<")?; @@ -1955,7 +2031,11 @@ impl<W: Write> Writer<W> { self.put_expression(expr, context, true)?; write!(self.out, ")")?; } - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Unsupported type for As: {ty:?}" + ))) + } }, // has to be a named expression crate::Expression::CallResult(_) @@ -1970,11 +2050,19 @@ impl<W: Write> Writer<W> { crate::Expression::AccessIndex { base, .. } => { match context.function.expressions[base] { crate::Expression::GlobalVariable(handle) => handle, - _ => return Err(Error::Validation), + ref ex => { + return Err(Error::GenericValidation(format!( + "Expected global variable in AccessIndex, got {ex:?}" + ))) + } } } crate::Expression::GlobalVariable(handle) => handle, - _ => return Err(Error::Validation), + ref ex => { + return Err(Error::GenericValidation(format!( + "Unexpected expression in ArrayLength, got {ex:?}" + ))) + } }; if !is_scoped { @@ -2140,10 +2228,12 @@ impl<W: Write> Writer<W> { match length { index::IndexableLength::Known(value) => write!(self.out, "{value}")?, index::IndexableLength::Dynamic => { - let global = context - .function - .originating_global(base) - .ok_or(Error::Validation)?; + let global = + context.function.originating_global(base).ok_or_else(|| { + Error::GenericValidation( + "Could not find originating global".into(), + ) + })?; write!(self.out, "1 + ")?; self.put_dynamic_array_max_index(global, context)? } @@ -2300,10 +2390,9 @@ impl<W: Write> Writer<W> { write!(self.out, "{}u", limit - 1)?; } index::IndexableLength::Dynamic => { - let global = context - .function - .originating_global(base) - .ok_or(Error::Validation)?; + let global = context.function.originating_global(base).ok_or_else(|| { + Error::GenericValidation("Could not find originating global".into()) + })?; self.put_dynamic_array_max_index(global, context)?; } } @@ -2489,7 +2578,14 @@ impl<W: Write> Writer<W> { } } - if let Expression::Math { fun, arg, arg1, .. } = *expr { + if let Expression::Math { + fun, + arg, + arg1, + arg2, + .. + } = *expr + { match fun { crate::MathFunction::Dot => { // WGSL's `dot` function works on any `vecN` type, but Metal's only @@ -2514,6 +2610,14 @@ impl<W: Write> Writer<W> { crate::MathFunction::FindMsb => { self.need_bake_expressions.insert(arg); } + crate::MathFunction::ExtractBits => { + // Only argument 1 is re-used. + self.need_bake_expressions.insert(arg1.unwrap()); + } + crate::MathFunction::InsertBits => { + // Only argument 2 is re-used. + self.need_bake_expressions.insert(arg2.unwrap()); + } crate::MathFunction::Sign => { // WGSL's `sign` function works also on signed ints, but Metal's only // works on floating points, so we emit inline code for integer `sign` @@ -3048,7 +3152,7 @@ impl<W: Write> Writer<W> { for statement in statements { if let crate::Statement::Emit(ref range) = *statement { for handle in range.clone() { - self.named_expressions.remove(&handle); + self.named_expressions.shift_remove(&handle); } } } @@ -3897,7 +4001,9 @@ impl<W: Write> Writer<W> { binding: None, first_time: true, }; - let binding = binding.ok_or(Error::Validation)?; + let binding = binding.ok_or_else(|| { + Error::GenericValidation("Expected binding, got None".into()) + })?; if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding { has_point_size = true; |