diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 01:13:27 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 01:13:27 +0000 |
commit | 40a355a42d4a9444dc753c04c6608dade2f06a23 (patch) | |
tree | 871fc667d2de662f171103ce5ec067014ef85e61 /third_party/rust/naga/src | |
parent | Adding upstream version 124.0.1. (diff) | |
download | firefox-40a355a42d4a9444dc753c04c6608dade2f06a23.tar.xz firefox-40a355a42d4a9444dc753c04c6608dade2f06a23.zip |
Adding upstream version 125.0.1.upstream/125.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src')
30 files changed, 1454 insertions, 681 deletions
diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs index e7de05f695..99c128c6d9 100644 --- a/third_party/rust/naga/src/back/glsl/features.rs +++ b/third_party/rust/naga/src/back/glsl/features.rs @@ -1,8 +1,8 @@ use super::{BackendResult, Error, Version, Writer}; use crate::{ back::glsl::{Options, WriterFlags}, - AddressSpace, Binding, Expression, Handle, ImageClass, ImageDimension, Interpolation, Sampling, - Scalar, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner, + AddressSpace, Binding, Expression, Handle, ImageClass, ImageDimension, Interpolation, + SampleLevel, Sampling, Scalar, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner, }; use std::fmt::Write; @@ -48,6 +48,8 @@ bitflags::bitflags! { /// /// We can always support this, either through the language or a polyfill const INSTANCE_INDEX = 1 << 22; + /// Sample specific LODs of cube / array shadow textures + const TEXTURE_SHADOW_LOD = 1 << 23; } } @@ -125,6 +127,7 @@ impl FeaturesManager { check_feature!(TEXTURE_SAMPLES, 150); check_feature!(TEXTURE_LEVELS, 130); check_feature!(IMAGE_SIZE, 430, 310); + check_feature!(TEXTURE_SHADOW_LOD, 200, 300); // Return an error if there are missing features if missing.is_empty() { @@ -251,6 +254,11 @@ impl FeaturesManager { } } + if self.0.contains(Features::TEXTURE_SHADOW_LOD) { + // https://registry.khronos.org/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt + writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?; + } + Ok(()) } } @@ -469,6 +477,47 @@ impl<'a, W> Writer<'a, W> { } } } + Expression::ImageSample { image, level, offset, .. } => { + if let TypeInner::Image { + dim, + arrayed, + class: ImageClass::Depth { .. }, + } = *info[image].ty.inner_with(&module.types) { + let lod = matches!(level, SampleLevel::Zero | SampleLevel::Exact(_)); + let bias = matches!(level, SampleLevel::Bias(_)); + let auto = matches!(level, SampleLevel::Auto); + let cube = dim == ImageDimension::Cube; + let array2d = dim == ImageDimension::D2 && arrayed; + let gles = self.options.version.is_es(); + + // We have a workaround of using `textureGrad` instead of `textureLod` if the LOD is zero, + // so we don't *need* this extension for those cases. + // But if we're explicitly allowed to use the extension (`WriterFlags::TEXTURE_SHADOW_LOD`), + // we always use it instead of the workaround. + let grad_workaround_applicable = (array2d || (cube && !arrayed)) && level == SampleLevel::Zero; + let prefer_grad_workaround = grad_workaround_applicable && !self.options.writer_flags.contains(WriterFlags::TEXTURE_SHADOW_LOD); + + let mut ext_used = false; + + // float texture(sampler2DArrayShadow sampler, vec4 P [, float bias]) + // float texture(samplerCubeArrayShadow sampler, vec4 P, float compare [, float bias]) + ext_used |= (array2d || cube && arrayed) && bias; + + // The non `bias` version of this was standardized in GL 4.3, but never in GLES. + // float textureOffset(sampler2DArrayShadow sampler, vec4 P, ivec2 offset [, float bias]) + ext_used |= array2d && (bias || (gles && auto)) && offset.is_some(); + + // float textureLod(sampler2DArrayShadow sampler, vec4 P, float lod) + // float textureLodOffset(sampler2DArrayShadow sampler, vec4 P, float lod, ivec2 offset) + // float textureLod(samplerCubeShadow sampler, vec4 P, float lod) + // float textureLod(samplerCubeArrayShadow sampler, vec4 P, float compare, float lod) + ext_used |= (cube || array2d) && lod && !prefer_grad_workaround; + + if ext_used { + features.request(Features::TEXTURE_SHADOW_LOD); + } + } + } _ => {} } } diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs index e346d43257..9bda594610 100644 --- a/third_party/rust/naga/src/back/glsl/mod.rs +++ b/third_party/rust/naga/src/back/glsl/mod.rs @@ -178,7 +178,7 @@ impl Version { /// Note: `location=` for vertex inputs and fragment outputs is supported /// unconditionally for GLES 300. fn supports_explicit_locations(&self) -> bool { - *self >= Version::Desktop(410) || *self >= Version::new_gles(310) + *self >= Version::Desktop(420) || *self >= Version::new_gles(310) } fn supports_early_depth_test(&self) -> bool { @@ -646,16 +646,6 @@ impl<'a, W: Write> Writer<'a, W> { // preprocessor not the processor ¯\_(ツ)_/¯ self.features.write(self.options, &mut self.out)?; - // Write the additional extensions - if self - .options - .writer_flags - .contains(WriterFlags::TEXTURE_SHADOW_LOD) - { - // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt - writeln!(self.out, "#extension GL_EXT_texture_shadow_lod : require")?; - } - // glsl es requires a precision to be specified for floats and ints // TODO: Should this be user configurable? if es { @@ -1300,7 +1290,14 @@ impl<'a, W: Write> Writer<'a, W> { let inner = expr_info.ty.inner_with(&self.module.types); - if let Expression::Math { fun, arg, arg1, .. } = *expr { + if let Expression::Math { + fun, + arg, + arg1, + arg2, + .. + } = *expr + { match fun { crate::MathFunction::Dot => { // if the expression is a Dot product with integer arguments, @@ -1315,6 +1312,14 @@ impl<'a, W: Write> Writer<'a, W> { } } } + crate::MathFunction::ExtractBits => { + // Only argument 1 is re-used. + self.need_bake_expressions.insert(arg1.unwrap()); + } + crate::MathFunction::InsertBits => { + // Only argument 2 is re-used. + self.need_bake_expressions.insert(arg2.unwrap()); + } crate::MathFunction::CountLeadingZeros => { if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { self.need_bake_expressions.insert(arg); @@ -2451,6 +2456,9 @@ impl<'a, W: Write> Writer<'a, W> { crate::Literal::I64(_) => { return Err(Error::Custom("GLSL has no 64-bit integer type".into())); } + crate::Literal::U64(_) => { + return Err(Error::Custom("GLSL has no 64-bit integer type".into())); + } crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { return Err(Error::Custom( "Abstract types should not appear in IR presented to backends".into(), @@ -2620,51 +2628,49 @@ impl<'a, W: Write> Writer<'a, W> { level, depth_ref, } => { - let dim = match *ctx.resolve_type(image, &self.module.types) { - TypeInner::Image { dim, .. } => dim, + let (dim, class, arrayed) = match *ctx.resolve_type(image, &self.module.types) { + TypeInner::Image { + dim, + class, + arrayed, + .. + } => (dim, class, arrayed), _ => unreachable!(), }; - - if dim == crate::ImageDimension::Cube - && array_index.is_some() - && depth_ref.is_some() - { - match level { - crate::SampleLevel::Zero - | crate::SampleLevel::Exact(_) - | crate::SampleLevel::Gradient { .. } - | crate::SampleLevel::Bias(_) => { - return Err(Error::Custom(String::from( - "gsamplerCubeArrayShadow isn't supported in textureGrad, \ - textureLod or texture with bias", - ))) - } - crate::SampleLevel::Auto => {} + let mut err = None; + if dim == crate::ImageDimension::Cube { + if offset.is_some() { + err = Some("gsamplerCube[Array][Shadow] doesn't support texture sampling with offsets"); + } + if arrayed + && matches!(class, crate::ImageClass::Depth { .. }) + && matches!(level, crate::SampleLevel::Gradient { .. }) + { + err = Some("samplerCubeArrayShadow don't support textureGrad"); } } + if gather.is_some() && level != crate::SampleLevel::Zero { + err = Some("textureGather doesn't support LOD parameters"); + } + if let Some(err) = err { + return Err(Error::Custom(String::from(err))); + } - // textureLod on sampler2DArrayShadow and samplerCubeShadow does not exist in GLSL. - // To emulate this, we will have to use textureGrad with a constant gradient of 0. - let workaround_lod_array_shadow_as_grad = (array_index.is_some() - || dim == crate::ImageDimension::Cube) - && depth_ref.is_some() - && gather.is_none() - && !self - .options - .writer_flags - .contains(WriterFlags::TEXTURE_SHADOW_LOD); - - //Write the function to be used depending on the sample level + // `textureLod[Offset]` on `sampler2DArrayShadow` and `samplerCubeShadow` does not exist in GLSL, + // unless `GL_EXT_texture_shadow_lod` is present. + // But if the target LOD is zero, we can emulate that by using `textureGrad[Offset]` with a constant gradient of 0. + let workaround_lod_with_grad = ((dim == crate::ImageDimension::Cube && !arrayed) + || (dim == crate::ImageDimension::D2 && arrayed)) + && level == crate::SampleLevel::Zero + && matches!(class, crate::ImageClass::Depth { .. }) + && !self.features.contains(Features::TEXTURE_SHADOW_LOD); + + // Write the function to be used depending on the sample level let fun_name = match level { crate::SampleLevel::Zero if gather.is_some() => "textureGather", + crate::SampleLevel::Zero if workaround_lod_with_grad => "textureGrad", crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture", - crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => { - if workaround_lod_array_shadow_as_grad { - "textureGrad" - } else { - "textureLod" - } - } + crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => "textureLod", crate::SampleLevel::Gradient { .. } => "textureGrad", }; let offset_name = match offset { @@ -2727,7 +2733,7 @@ impl<'a, W: Write> Writer<'a, W> { crate::SampleLevel::Auto => (), // Zero needs level set to 0 crate::SampleLevel::Zero => { - if workaround_lod_array_shadow_as_grad { + if workaround_lod_with_grad { let vec_dim = match dim { crate::ImageDimension::Cube => 3, _ => 2, @@ -2739,13 +2745,8 @@ impl<'a, W: Write> Writer<'a, W> { } // Exact and bias require another argument crate::SampleLevel::Exact(expr) => { - if workaround_lod_array_shadow_as_grad { - log::warn!("Unable to `textureLod` a shadow array, ignoring the LOD"); - write!(self.out, ", vec2(0,0), vec2(0,0)")?; - } else { - write!(self.out, ", ")?; - self.write_expr(expr, ctx)?; - } + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; } crate::SampleLevel::Bias(_) => { // This needs to be done after the offset writing @@ -3155,7 +3156,29 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Abs => "abs", Mf::Min => "min", Mf::Max => "max", - Mf::Clamp => "clamp", + Mf::Clamp => { + let scalar_kind = ctx + .resolve_type(arg, &self.module.types) + .scalar_kind() + .unwrap(); + match scalar_kind { + crate::ScalarKind::Float => "clamp", + // Clamp is undefined if min > max. In practice this means it can use a median-of-three + // instruction to determine the value. This is fine according to the WGSL spec for float + // clamp, but integer clamp _must_ use min-max. As such we write out min/max. + _ => { + write!(self.out, "min(max(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ", ")?; + self.write_expr(arg1.unwrap(), ctx)?; + write!(self.out, "), ")?; + self.write_expr(arg2.unwrap(), ctx)?; + write!(self.out, ")")?; + + return Ok(()); + } + } + } Mf::Saturate => { write!(self.out, "clamp(")?; @@ -3370,8 +3393,59 @@ impl<'a, W: Write> Writer<'a, W> { } Mf::CountOneBits => "bitCount", Mf::ReverseBits => "bitfieldReverse", - Mf::ExtractBits => "bitfieldExtract", - Mf::InsertBits => "bitfieldInsert", + Mf::ExtractBits => { + // The behavior of ExtractBits is undefined when offset + count > bit_width. We need + // to first sanitize the offset and count first. If we don't do this, AMD and Intel chips + // will return out-of-spec values if the extracted range is not within the bit width. + // + // This encodes the exact formula specified by the wgsl spec, without temporary values: + // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin + // + // w = sizeof(x) * 8 + // o = min(offset, w) + // c = min(count, w - o) + // + // bitfieldExtract(x, o, c) + // + // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))) + let scalar_bits = ctx + .resolve_type(arg, &self.module.types) + .scalar_width() + .unwrap(); + + write!(self.out, "bitfieldExtract(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ", int(min(")?; + self.write_expr(arg1.unwrap(), ctx)?; + write!(self.out, ", {scalar_bits}u)), int(min(",)?; + self.write_expr(arg2.unwrap(), ctx)?; + write!(self.out, ", {scalar_bits}u - min(")?; + self.write_expr(arg1.unwrap(), ctx)?; + write!(self.out, ", {scalar_bits}u))))")?; + + return Ok(()); + } + Mf::InsertBits => { + // InsertBits has the same considerations as ExtractBits above + let scalar_bits = ctx + .resolve_type(arg, &self.module.types) + .scalar_width() + .unwrap(); + + write!(self.out, "bitfieldInsert(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ", ")?; + self.write_expr(arg1.unwrap(), ctx)?; + write!(self.out, ", int(min(")?; + self.write_expr(arg2.unwrap(), ctx)?; + write!(self.out, ", {scalar_bits}u)), int(min(",)?; + self.write_expr(arg3.unwrap(), ctx)?; + write!(self.out, ", {scalar_bits}u - min(")?; + self.write_expr(arg2.unwrap(), ctx)?; + write!(self.out, ", {scalar_bits}u))))")?; + + return Ok(()); + } Mf::FindLsb => "findLSB", Mf::FindMsb => "findMSB", // data packing diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs index b6918ddc42..2a6db35db8 100644 --- a/third_party/rust/naga/src/back/hlsl/conv.rs +++ b/third_party/rust/naga/src/back/hlsl/conv.rs @@ -21,8 +21,16 @@ impl crate::Scalar { /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar> pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> { match self.kind { - crate::ScalarKind::Sint => Ok("int"), - crate::ScalarKind::Uint => Ok("uint"), + crate::ScalarKind::Sint => match self.width { + 4 => Ok("int"), + 8 => Ok("int64_t"), + _ => Err(Error::UnsupportedScalar(self)), + }, + crate::ScalarKind::Uint => match self.width { + 4 => Ok("uint"), + 8 => Ok("uint64_t"), + _ => Err(Error::UnsupportedScalar(self)), + }, crate::ScalarKind::Float => match self.width { 2 => Ok("half"), 4 => Ok("float"), diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs index fa6062a1ad..4dd9ea5987 100644 --- a/third_party/rust/naga/src/back/hlsl/help.rs +++ b/third_party/rust/naga/src/back/hlsl/help.rs @@ -26,7 +26,11 @@ int dim_1d = NagaDimensions1D(image_1d); ``` */ -use super::{super::FunctionCtx, BackendResult}; +use super::{ + super::FunctionCtx, + writer::{EXTRACT_BITS_FUNCTION, INSERT_BITS_FUNCTION}, + BackendResult, +}; use crate::{arena::Handle, proc::NameKey}; use std::fmt::Write; @@ -59,6 +63,13 @@ pub(super) struct WrappedMatCx2 { pub(super) columns: crate::VectorSize, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedMath { + pub(super) fun: crate::MathFunction, + pub(super) scalar: crate::Scalar, + pub(super) components: Option<u32>, +} + /// HLSL backend requires its own `ImageQuery` enum. /// /// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function. @@ -851,12 +862,149 @@ impl<'a, W: Write> super::Writer<'a, W> { Ok(()) } + pub(super) fn write_wrapped_math_functions( + &mut self, + module: &crate::Module, + func_ctx: &FunctionCtx, + ) -> BackendResult { + for (_, expression) in func_ctx.expressions.iter() { + if let crate::Expression::Math { + fun, + arg, + arg1: _arg1, + arg2: _arg2, + arg3: _arg3, + } = *expression + { + match fun { + crate::MathFunction::ExtractBits => { + // The behavior of our extractBits polyfill is undefined if offset + count > bit_width. We need + // to first sanitize the offset and count first. If we don't do this, we will get out-of-spec + // values if the extracted range is not within the bit width. + // + // This encodes the exact formula specified by the wgsl spec: + // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin + // + // w = sizeof(x) * 8 + // o = min(offset, w) + // c = min(count, w - o) + // + // bitfieldExtract(x, o, c) + let arg_ty = func_ctx.resolve_type(arg, &module.types); + let scalar = arg_ty.scalar().unwrap(); + let components = arg_ty.components(); + + let wrapped = WrappedMath { + fun, + scalar, + components, + }; + + if !self.wrapped.math.insert(wrapped) { + continue; + } + + // Write return type + self.write_value_type(module, arg_ty)?; + + let scalar_width: u8 = scalar.width * 8; + + // Write function name and parameters + writeln!(self.out, " {EXTRACT_BITS_FUNCTION}(")?; + write!(self.out, " ")?; + self.write_value_type(module, arg_ty)?; + writeln!(self.out, " e,")?; + writeln!(self.out, " uint offset,")?; + writeln!(self.out, " uint count")?; + writeln!(self.out, ") {{")?; + + // Write function body + writeln!(self.out, " uint w = {scalar_width};")?; + writeln!(self.out, " uint o = min(offset, w);")?; + writeln!(self.out, " uint c = min(count, w - o);")?; + writeln!( + self.out, + " return (c == 0 ? 0 : (e << (w - c - o)) >> (w - c));" + )?; + + // End of function body + writeln!(self.out, "}}")?; + } + crate::MathFunction::InsertBits => { + // The behavior of our insertBits polyfill has the same constraints as the extractBits polyfill. + + let arg_ty = func_ctx.resolve_type(arg, &module.types); + let scalar = arg_ty.scalar().unwrap(); + let components = arg_ty.components(); + + let wrapped = WrappedMath { + fun, + scalar, + components, + }; + + if !self.wrapped.math.insert(wrapped) { + continue; + } + + // Write return type + self.write_value_type(module, arg_ty)?; + + let scalar_width: u8 = scalar.width * 8; + let scalar_max: u64 = match scalar.width { + 1 => 0xFF, + 2 => 0xFFFF, + 4 => 0xFFFFFFFF, + 8 => 0xFFFFFFFFFFFFFFFF, + _ => unreachable!(), + }; + + // Write function name and parameters + writeln!(self.out, " {INSERT_BITS_FUNCTION}(")?; + write!(self.out, " ")?; + self.write_value_type(module, arg_ty)?; + writeln!(self.out, " e,")?; + write!(self.out, " ")?; + self.write_value_type(module, arg_ty)?; + writeln!(self.out, " newbits,")?; + writeln!(self.out, " uint offset,")?; + writeln!(self.out, " uint count")?; + writeln!(self.out, ") {{")?; + + // Write function body + writeln!(self.out, " uint w = {scalar_width}u;")?; + writeln!(self.out, " uint o = min(offset, w);")?; + writeln!(self.out, " uint c = min(count, w - o);")?; + + // The `u` suffix on the literals is _extremely_ important. Otherwise it will use + // i32 shifting instead of the intended u32 shifting. + writeln!( + self.out, + " uint mask = (({scalar_max}u >> ({scalar_width}u - c)) << o);" + )?; + writeln!( + self.out, + " return (c == 0 ? e : ((e & ~mask) | ((newbits << o) & mask)));" + )?; + + // End of function body + writeln!(self.out, "}}")?; + } + _ => {} + } + } + } + + Ok(()) + } + /// Helper function that writes various wrapped functions pub(super) fn write_wrapped_functions( &mut self, module: &crate::Module, func_ctx: &FunctionCtx, ) -> BackendResult { + self.write_wrapped_math_functions(module, func_ctx)?; self.write_wrapped_compose_functions(module, func_ctx.expressions)?; for (handle, _) in func_ctx.expressions.iter() { diff --git a/third_party/rust/naga/src/back/hlsl/keywords.rs b/third_party/rust/naga/src/back/hlsl/keywords.rs index 059e533ff7..2cb715c42c 100644 --- a/third_party/rust/naga/src/back/hlsl/keywords.rs +++ b/third_party/rust/naga/src/back/hlsl/keywords.rs @@ -817,6 +817,8 @@ pub const RESERVED: &[&str] = &[ // Naga utilities super::writer::MODF_FUNCTION, super::writer::FREXP_FUNCTION, + super::writer::EXTRACT_BITS_FUNCTION, + super::writer::INSERT_BITS_FUNCTION, ]; // DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254 diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs index 37ddbd3d67..f37a223f47 100644 --- a/third_party/rust/naga/src/back/hlsl/mod.rs +++ b/third_party/rust/naga/src/back/hlsl/mod.rs @@ -256,6 +256,7 @@ struct Wrapped { constructors: crate::FastHashSet<help::WrappedConstructor>, struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>, mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>, + math: crate::FastHashSet<help::WrappedMath>, } impl Wrapped { @@ -265,6 +266,7 @@ impl Wrapped { self.constructors.clear(); self.struct_matrix_access.clear(); self.mat_cx2s.clear(); + self.math.clear(); } } diff --git a/third_party/rust/naga/src/back/hlsl/storage.rs b/third_party/rust/naga/src/back/hlsl/storage.rs index 1b8a6ec12d..4d3a6af56d 100644 --- a/third_party/rust/naga/src/back/hlsl/storage.rs +++ b/third_party/rust/naga/src/back/hlsl/storage.rs @@ -32,6 +32,16 @@ The [`temp_access_chain`] field is a member of [`Writer`] solely to allow re-use of the `Vec`'s dynamic allocation. Its value is no longer needed once HLSL for the access has been generated. +Note about DXC and Load/Store functions: + +DXC's HLSL has a generic [`Load` and `Store`] function for [`ByteAddressBuffer`] and +[`RWByteAddressBuffer`]. This is not available in FXC's HLSL, so we use +it only for types that are only available in DXC. Notably 64 and 16 bit types. + +FXC's HLSL has functions Load, Load2, Load3, and Load4 and Store, Store2, Store3, Store4. +This loads/stores a vector of length 1, 2, 3, or 4. We use that for 32bit types, bitcasting to the +correct type if necessary. + [`Storage`]: crate::AddressSpace::Storage [`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer [`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer @@ -42,6 +52,7 @@ needed once HLSL for the access has been generated. [`Writer::temp_access_chain`]: super::Writer::temp_access_chain [`temp_access_chain`]: super::Writer::temp_access_chain [`Writer`]: super::Writer +[`Load` and `Store`]: https://github.com/microsoft/DirectXShaderCompiler/wiki/ByteAddressBuffer-Load-Store-Additions */ use super::{super::FunctionCtx, BackendResult, Error}; @@ -161,20 +172,39 @@ impl<W: fmt::Write> super::Writer<'_, W> { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; - let cast = scalar.kind.to_hlsl_cast(); - write!(self.out, "{cast}({var_name}.Load(")?; + // See note about DXC and Load/Store in the module's documentation. + if scalar.width == 4 { + let cast = scalar.kind.to_hlsl_cast(); + write!(self.out, "{cast}({var_name}.Load(")?; + } else { + let ty = scalar.to_hlsl_str()?; + write!(self.out, "{var_name}.Load<{ty}>(")?; + }; self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, "))")?; + write!(self.out, ")")?; + if scalar.width == 4 { + write!(self.out, ")")?; + } self.temp_access_chain = chain; } crate::TypeInner::Vector { size, scalar } => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; - let cast = scalar.kind.to_hlsl_cast(); - write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?; + let size = size as u8; + // See note about DXC and Load/Store in the module's documentation. + if scalar.width == 4 { + let cast = scalar.kind.to_hlsl_cast(); + write!(self.out, "{cast}({var_name}.Load{size}(")?; + } else { + let ty = scalar.to_hlsl_str()?; + write!(self.out, "{var_name}.Load<{ty}{size}>(")?; + }; self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, "))")?; + write!(self.out, ")")?; + if scalar.width == 4 { + write!(self.out, ")")?; + } self.temp_access_chain = chain; } crate::TypeInner::Matrix { @@ -288,26 +318,44 @@ impl<W: fmt::Write> super::Writer<'_, W> { } }; match *ty_resolution.inner_with(&module.types) { - crate::TypeInner::Scalar(_) => { + crate::TypeInner::Scalar(scalar) => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; - write!(self.out, "{level}{var_name}.Store(")?; - self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, ", asuint(")?; - self.write_store_value(module, &value, func_ctx)?; - writeln!(self.out, "));")?; + // See note about DXC and Load/Store in the module's documentation. + if scalar.width == 4 { + write!(self.out, "{level}{var_name}.Store(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", asuint(")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, "));")?; + } else { + write!(self.out, "{level}{var_name}.Store(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", ")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ");")?; + } self.temp_access_chain = chain; } - crate::TypeInner::Vector { size, .. } => { + crate::TypeInner::Vector { size, scalar } => { // working around the borrow checker in `self.write_expr` let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; - write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?; - self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, ", asuint(")?; - self.write_store_value(module, &value, func_ctx)?; - writeln!(self.out, "));")?; + // See note about DXC and Load/Store in the module's documentation. + if scalar.width == 4 { + write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", asuint(")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, "));")?; + } else { + write!(self.out, "{}{}.Store(", level, var_name)?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", ")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ");")?; + } self.temp_access_chain = chain; } crate::TypeInner::Matrix { diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs index 43f7212837..4ba856946b 100644 --- a/third_party/rust/naga/src/back/hlsl/writer.rs +++ b/third_party/rust/naga/src/back/hlsl/writer.rs @@ -19,6 +19,8 @@ const SPECIAL_OTHER: &str = "other"; pub(crate) const MODF_FUNCTION: &str = "naga_modf"; pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; +pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits"; +pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits"; struct EpStructMember { name: String, @@ -125,14 +127,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.need_bake_expressions.insert(fun_handle); } - if let Expression::Math { - fun, - arg, - arg1, - arg2, - arg3, - } = *expr - { + if let Expression::Math { fun, arg, .. } = *expr { match fun { crate::MathFunction::Asinh | crate::MathFunction::Acosh @@ -149,17 +144,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { | crate::MathFunction::Pack4x8unorm => { self.need_bake_expressions.insert(arg); } - crate::MathFunction::ExtractBits => { - self.need_bake_expressions.insert(arg); - self.need_bake_expressions.insert(arg1.unwrap()); - self.need_bake_expressions.insert(arg2.unwrap()); - } - crate::MathFunction::InsertBits => { - self.need_bake_expressions.insert(arg); - self.need_bake_expressions.insert(arg1.unwrap()); - self.need_bake_expressions.insert(arg2.unwrap()); - self.need_bake_expressions.insert(arg3.unwrap()); - } crate::MathFunction::CountLeadingZeros => { let inner = info[fun_handle].ty.inner_with(&module.types); if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { @@ -2038,6 +2022,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::Literal::F32(value) => write!(self.out, "{value:?}")?, crate::Literal::U32(value) => write!(self.out, "{}u", value)?, crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::U64(value) => write!(self.out, "{}uL", value)?, crate::Literal::I64(value) => write!(self.out, "{}L", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { @@ -2567,7 +2552,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { convert, } => { let inner = func_ctx.resolve_type(expr, &module.types); - match convert { + let close_paren = match convert { Some(dst_width) => { let scalar = crate::Scalar { kind, @@ -2600,13 +2585,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { ))); } }; + true } None => { - write!(self.out, "{}(", kind.to_hlsl_cast(),)?; + if inner.scalar_width() == Some(64) { + false + } else { + write!(self.out, "{}(", kind.to_hlsl_cast(),)?; + true + } } - } + }; self.write_expr(module, expr, func_ctx)?; - write!(self.out, ")")?; + if close_paren { + write!(self.out, ")")?; + } } Expression::Math { fun, @@ -2620,8 +2613,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { enum Function { Asincosh { is_sin: bool }, Atanh, - ExtractBits, - InsertBits, Pack2x16float, Pack2x16snorm, Pack2x16unorm, @@ -2705,8 +2696,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Mf::ReverseBits => Function::MissingIntOverload("reversebits"), Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"), Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"), - Mf::ExtractBits => Function::ExtractBits, - Mf::InsertBits => Function::InsertBits, + Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION), + Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION), // Data Packing Mf::Pack2x16float => Function::Pack2x16float, Mf::Pack2x16snorm => Function::Pack2x16snorm, @@ -2742,70 +2733,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } - Function::ExtractBits => { - // e: T, - // offset: u32, - // count: u32 - // T is u32 or i32 or vecN<u32> or vecN<i32> - if let (Some(offset), Some(count)) = (arg1, arg2) { - let scalar_width: u8 = 32; - // Works for signed and unsigned - // (count == 0 ? 0 : (e << (32 - count - offset)) >> (32 - count)) - write!(self.out, "(")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, " == 0 ? 0 : (")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, " << ({scalar_width} - ")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, " - ")?; - self.write_expr(module, offset, func_ctx)?; - write!(self.out, ")) >> ({scalar_width} - ")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, "))")?; - } - } - Function::InsertBits => { - // e: T, - // newbits: T, - // offset: u32, - // count: u32 - // returns T - // T is i32, u32, vecN<i32>, or vecN<u32> - if let (Some(newbits), Some(offset), Some(count)) = (arg1, arg2, arg3) { - let scalar_width: u8 = 32; - let scalar_max: u32 = 0xFFFFFFFF; - // mask = ((0xFFFFFFFFu >> (32 - count)) << offset) - // (count == 0 ? e : ((e & ~mask) | ((newbits << offset) & mask))) - write!(self.out, "(")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, " == 0 ? ")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, " : ")?; - write!(self.out, "(")?; - self.write_expr(module, arg, func_ctx)?; - write!(self.out, " & ~")?; - // mask - write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, ")) << ")?; - self.write_expr(module, offset, func_ctx)?; - write!(self.out, ")")?; - // end mask - write!(self.out, ") | ((")?; - self.write_expr(module, newbits, func_ctx)?; - write!(self.out, " << ")?; - self.write_expr(module, offset, func_ctx)?; - write!(self.out, ") & ")?; - // // mask - write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?; - self.write_expr(module, count, func_ctx)?; - write!(self.out, ")) << ")?; - self.write_expr(module, offset, func_ctx)?; - write!(self.out, ")")?; - // // end mask - write!(self.out, "))")?; - } - } Function::Pack2x16float => { write!(self.out, "(f32tof16(")?; self.write_expr(module, arg, func_ctx)?; @@ -2944,9 +2871,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } write!(self.out, ")")? } + // These overloads are only missing on FXC, so this is only needed for 32bit types, + // as non-32bit types are DXC only. Function::MissingIntOverload(fun_name) => { - let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); - if let Some(ScalarKind::Sint) = scalar_kind { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar(); + if let Some(crate::Scalar { + kind: ScalarKind::Sint, + width: 4, + }) = scalar_kind + { write!(self.out, "asint({fun_name}(asuint(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; @@ -2956,9 +2889,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; } } + // These overloads are only missing on FXC, so this is only needed for 32bit types, + // as non-32bit types are DXC only. Function::MissingIntReturnType(fun_name) => { - let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); - if let Some(ScalarKind::Sint) = scalar_kind { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar(); + if let Some(crate::Scalar { + kind: ScalarKind::Sint, + width: 4, + }) = scalar_kind + { write!(self.out, "asint({fun_name}(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; @@ -2977,23 +2916,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::VectorSize::Quad => ".xxxx", }; - if let ScalarKind::Uint = scalar.kind { - write!(self.out, "min((32u){s}, firstbitlow(")?; + let scalar_width_bits = scalar.width * 8; + + if scalar.kind == ScalarKind::Uint || scalar.width != 4 { + write!( + self.out, + "min(({scalar_width_bits}u){s}, firstbitlow(" + )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { - write!(self.out, "asint(min((32u){s}, firstbitlow(")?; + // This is only needed for the FXC path, on 32bit signed integers. + write!( + self.out, + "asint(min(({scalar_width_bits}u){s}, firstbitlow(" + )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } } TypeInner::Scalar(scalar) => { - if let ScalarKind::Uint = scalar.kind { - write!(self.out, "min(32u, firstbitlow(")?; + let scalar_width_bits = scalar.width * 8; + + if scalar.kind == ScalarKind::Uint || scalar.width != 4 { + write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { - write!(self.out, "asint(min(32u, firstbitlow(")?; + // This is only needed for the FXC path, on 32bit signed integers. + write!( + self.out, + "asint(min({scalar_width_bits}u, firstbitlow(" + )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } @@ -3012,30 +2966,47 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::VectorSize::Quad => ".xxxx", }; - if let ScalarKind::Uint = scalar.kind { - write!(self.out, "((31u){s} - firstbithigh(")?; + // scalar width - 1 + let constant = scalar.width * 8 - 1; + + if scalar.kind == ScalarKind::Uint { + write!(self.out, "(({constant}u){s} - firstbithigh(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { + let conversion_func = match scalar.width { + 4 => "asint", + _ => "", + }; write!(self.out, "(")?; self.write_expr(module, arg, func_ctx)?; write!( self.out, - " < (0){s} ? (0){s} : (31){s} - asint(firstbithigh(" + " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh(" )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } } TypeInner::Scalar(scalar) => { + // scalar width - 1 + let constant = scalar.width * 8 - 1; + if let ScalarKind::Uint = scalar.kind { - write!(self.out, "(31u - firstbithigh(")?; + write!(self.out, "({constant}u - firstbithigh(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { + let conversion_func = match scalar.width { + 4 => "asint", + _ => "", + }; write!(self.out, "(")?; self.write_expr(module, arg, func_ctx)?; - write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?; + write!( + self.out, + " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh(" + )?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; } diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs index f0025bf239..73c457dd34 100644 --- a/third_party/rust/naga/src/back/msl/keywords.rs +++ b/third_party/rust/naga/src/back/msl/keywords.rs @@ -4,6 +4,8 @@ // C++ - Standard for Programming Language C++ (N4431) // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf pub const RESERVED: &[&str] = &[ + // Undocumented + "assert", // found in https://github.com/gfx-rs/wgpu/issues/5347 // Standard for Programming Language C++ (N4431): 2.5 Alternative tokens "and", "bitor", diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs index 5ef18730c9..68e5b79906 100644 --- a/third_party/rust/naga/src/back/msl/mod.rs +++ b/third_party/rust/naga/src/back/msl/mod.rs @@ -121,8 +121,8 @@ pub enum Error { UnsupportedCall(String), #[error("feature '{0}' is not implemented yet")] FeatureNotImplemented(String), - #[error("module is not valid")] - Validation, + #[error("internal naga error: module should not have validated: {0}")] + GenericValidation(String), #[error("BuiltIn {0:?} is not supported")] UnsupportedBuiltIn(crate::BuiltIn), #[error("capability {0:?} is not supported")] @@ -306,13 +306,10 @@ impl Options { }, }) } - LocationMode::Uniform => { - log::error!( - "Unexpected Binding::Location({}) for the Uniform mode", - location - ); - Err(Error::Validation) - } + LocationMode::Uniform => Err(Error::GenericValidation(format!( + "Unexpected Binding::Location({}) for the Uniform mode", + location + ))), }, } } diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs index 1e496b5f50..5227d8e7db 100644 --- a/third_party/rust/naga/src/back/msl/writer.rs +++ b/third_party/rust/naga/src/back/msl/writer.rs @@ -319,7 +319,7 @@ pub struct Writer<W> { } impl crate::Scalar { - const fn to_msl_name(self) -> &'static str { + fn to_msl_name(self) -> &'static str { use crate::ScalarKind as Sk; match self { Self { @@ -328,20 +328,29 @@ impl crate::Scalar { } => "float", Self { kind: Sk::Sint, - width: _, + width: 4, } => "int", Self { kind: Sk::Uint, - width: _, + width: 4, } => "uint", Self { + kind: Sk::Sint, + width: 8, + } => "long", + Self { + kind: Sk::Uint, + width: 8, + } => "ulong", + Self { kind: Sk::Bool, width: _, } => "bool", Self { kind: Sk::AbstractInt | Sk::AbstractFloat, width: _, - } => unreachable!(), + } => unreachable!("Found Abstract scalar kind"), + _ => unreachable!("Unsupported scalar kind: {:?}", self), } } } @@ -735,7 +744,11 @@ impl<W: Write> Writer<W> { crate::TypeInner::Vector { size, .. } => { put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])? } - _ => return Err(Error::Validation), + _ => { + return Err(Error::GenericValidation( + "Invalid type for image coordinate".into(), + )) + } }; write!(self.out, "(")?; @@ -1068,13 +1081,17 @@ impl<W: Write> Writer<W> { let (offset, array_ty) = match context.module.types[global.ty].inner { crate::TypeInner::Struct { ref members, .. } => match members.last() { Some(&crate::StructMember { offset, ty, .. }) => (offset, ty), - None => return Err(Error::Validation), + None => return Err(Error::GenericValidation("Struct has no members".into())), }, crate::TypeInner::Array { size: crate::ArraySize::Dynamic, .. } => (0, global.ty), - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Expected type with dynamic array, got {ty:?}" + ))) + } }; let (size, stride) = match context.module.types[array_ty].inner { @@ -1084,7 +1101,11 @@ impl<W: Write> Writer<W> { .size(context.module.to_ctx()), stride, ), - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Expected array type, got {ty:?}" + ))) + } }; // When the stride length is larger than the size, the final element's stride of @@ -1273,6 +1294,9 @@ impl<W: Write> Writer<W> { crate::Literal::I32(value) => { write!(self.out, "{value}")?; } + crate::Literal::U64(value) => { + write!(self.out, "{value}uL")?; + } crate::Literal::I64(value) => { write!(self.out, "{value}L")?; } @@ -1280,7 +1304,9 @@ impl<W: Write> Writer<W> { write!(self.out, "{value}")?; } crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { - return Err(Error::Validation); + return Err(Error::GenericValidation( + "Unsupported abstract literal".into(), + )); } }, crate::Expression::Constant(handle) => { @@ -1342,7 +1368,11 @@ impl<W: Write> Writer<W> { crate::Expression::Splat { size, value } => { let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) { crate::TypeInner::Scalar(scalar) => scalar, - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Expected splat value type must be a scalar, got {ty:?}", + ))) + } }; put_numeric_type(&mut self.out, scalar, &[size])?; write!(self.out, "(")?; @@ -1672,7 +1702,11 @@ impl<W: Write> Writer<W> { self.put_expression(condition, context, true)?; write!(self.out, ")")?; } - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Expected select condition to be a non-bool type, got {ty:?}", + ))) + } }, crate::Expression::Derivative { axis, expr, .. } => { use crate::DerivativeAxis as Axis; @@ -1794,8 +1828,8 @@ impl<W: Write> Writer<W> { Mf::CountLeadingZeros => "clz", Mf::CountOneBits => "popcount", Mf::ReverseBits => "reverse_bits", - Mf::ExtractBits => "extract_bits", - Mf::InsertBits => "insert_bits", + Mf::ExtractBits => "", + Mf::InsertBits => "", Mf::FindLsb => "", Mf::FindMsb => "", // data packing @@ -1836,15 +1870,23 @@ impl<W: Write> Writer<W> { self.put_expression(arg1.unwrap(), context, false)?; write!(self.out, ")")?; } else if fun == Mf::FindLsb { + let scalar = context.resolve_type(arg).scalar().unwrap(); + let constant = scalar.width * 8 + 1; + write!(self.out, "((({NAMESPACE}::ctz(")?; self.put_expression(arg, context, true)?; - write!(self.out, ") + 1) % 33) - 1)")?; + write!(self.out, ") + 1) % {constant}) - 1)")?; } else if fun == Mf::FindMsb { let inner = context.resolve_type(arg); + let scalar = inner.scalar().unwrap(); + let constant = scalar.width * 8 - 1; - write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?; + write!( + self.out, + "{NAMESPACE}::select({constant} - {NAMESPACE}::clz(" + )?; - if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { + if scalar.kind == crate::ScalarKind::Sint { write!(self.out, "{NAMESPACE}::select(")?; self.put_expression(arg, context, true)?; write!(self.out, ", ~")?; @@ -1862,18 +1904,12 @@ impl<W: Write> Writer<W> { match *inner { crate::TypeInner::Vector { size, scalar } => { let size = back::vector_size_str(size); - if let crate::ScalarKind::Sint = scalar.kind { - write!(self.out, "int{size}")?; - } else { - write!(self.out, "uint{size}")?; - } + let name = scalar.to_msl_name(); + write!(self.out, "{name}{size}")?; } crate::TypeInner::Scalar(scalar) => { - if let crate::ScalarKind::Sint = scalar.kind { - write!(self.out, "int")?; - } else { - write!(self.out, "uint")?; - } + let name = scalar.to_msl_name(); + write!(self.out, "{name}")?; } _ => (), } @@ -1891,6 +1927,52 @@ impl<W: Write> Writer<W> { write!(self.out, "as_type<uint>(half2(")?; self.put_expression(arg, context, false)?; write!(self.out, "))")?; + } else if fun == Mf::ExtractBits { + // The behavior of ExtractBits is undefined when offset + count > bit_width. We need + // to first sanitize the offset and count first. If we don't do this, Apple chips + // will return out-of-spec values if the extracted range is not within the bit width. + // + // This encodes the exact formula specified by the wgsl spec, without temporary values: + // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin + // + // w = sizeof(x) * 8 + // o = min(offset, w) + // tmp = w - o + // c = min(count, tmp) + // + // bitfieldExtract(x, o, c) + // + // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))) + + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap(); + + write!(self.out, "{NAMESPACE}::extract_bits(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", {NAMESPACE}::min(")?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; + self.put_expression(arg2.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u)))")?; + } else if fun == Mf::InsertBits { + // The behavior of InsertBits has the same issue as ExtractBits. + // + // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))) + + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap(); + + write!(self.out, "{NAMESPACE}::insert_bits(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ", {NAMESPACE}::min(")?; + self.put_expression(arg2.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; + self.put_expression(arg3.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; + self.put_expression(arg2.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u)))")?; } else if fun == Mf::Radians { write!(self.out, "((")?; self.put_expression(arg, context, false)?; @@ -1920,14 +2002,8 @@ impl<W: Write> Writer<W> { kind, width: convert.unwrap_or(src.width), }; - let is_bool_cast = - kind == crate::ScalarKind::Bool || src.kind == crate::ScalarKind::Bool; let op = match convert { - Some(w) if w == src.width || is_bool_cast => "static_cast", - Some(8) if kind == crate::ScalarKind::Float => { - return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) - } - Some(_) => return Err(Error::Validation), + Some(_) => "static_cast", None => "as_type", }; write!(self.out, "{op}<")?; @@ -1955,7 +2031,11 @@ impl<W: Write> Writer<W> { self.put_expression(expr, context, true)?; write!(self.out, ")")?; } - _ => return Err(Error::Validation), + ref ty => { + return Err(Error::GenericValidation(format!( + "Unsupported type for As: {ty:?}" + ))) + } }, // has to be a named expression crate::Expression::CallResult(_) @@ -1970,11 +2050,19 @@ impl<W: Write> Writer<W> { crate::Expression::AccessIndex { base, .. } => { match context.function.expressions[base] { crate::Expression::GlobalVariable(handle) => handle, - _ => return Err(Error::Validation), + ref ex => { + return Err(Error::GenericValidation(format!( + "Expected global variable in AccessIndex, got {ex:?}" + ))) + } } } crate::Expression::GlobalVariable(handle) => handle, - _ => return Err(Error::Validation), + ref ex => { + return Err(Error::GenericValidation(format!( + "Unexpected expression in ArrayLength, got {ex:?}" + ))) + } }; if !is_scoped { @@ -2140,10 +2228,12 @@ impl<W: Write> Writer<W> { match length { index::IndexableLength::Known(value) => write!(self.out, "{value}")?, index::IndexableLength::Dynamic => { - let global = context - .function - .originating_global(base) - .ok_or(Error::Validation)?; + let global = + context.function.originating_global(base).ok_or_else(|| { + Error::GenericValidation( + "Could not find originating global".into(), + ) + })?; write!(self.out, "1 + ")?; self.put_dynamic_array_max_index(global, context)? } @@ -2300,10 +2390,9 @@ impl<W: Write> Writer<W> { write!(self.out, "{}u", limit - 1)?; } index::IndexableLength::Dynamic => { - let global = context - .function - .originating_global(base) - .ok_or(Error::Validation)?; + let global = context.function.originating_global(base).ok_or_else(|| { + Error::GenericValidation("Could not find originating global".into()) + })?; self.put_dynamic_array_max_index(global, context)?; } } @@ -2489,7 +2578,14 @@ impl<W: Write> Writer<W> { } } - if let Expression::Math { fun, arg, arg1, .. } = *expr { + if let Expression::Math { + fun, + arg, + arg1, + arg2, + .. + } = *expr + { match fun { crate::MathFunction::Dot => { // WGSL's `dot` function works on any `vecN` type, but Metal's only @@ -2514,6 +2610,14 @@ impl<W: Write> Writer<W> { crate::MathFunction::FindMsb => { self.need_bake_expressions.insert(arg); } + crate::MathFunction::ExtractBits => { + // Only argument 1 is re-used. + self.need_bake_expressions.insert(arg1.unwrap()); + } + crate::MathFunction::InsertBits => { + // Only argument 2 is re-used. + self.need_bake_expressions.insert(arg2.unwrap()); + } crate::MathFunction::Sign => { // WGSL's `sign` function works also on signed ints, but Metal's only // works on floating points, so we emit inline code for integer `sign` @@ -3048,7 +3152,7 @@ impl<W: Write> Writer<W> { for statement in statements { if let crate::Statement::Emit(ref range) = *statement { for handle in range.clone() { - self.named_expressions.remove(&handle); + self.named_expressions.shift_remove(&handle); } } } @@ -3897,7 +4001,9 @@ impl<W: Write> Writer<W> { binding: None, first_time: true, }; - let binding = binding.ok_or(Error::Validation)?; + let binding = binding.ok_or_else(|| { + Error::GenericValidation("Expected binding, got None".into()) + })?; if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding { has_point_size = true; diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs index 6c96fa09e3..81f2fc10e0 100644 --- a/third_party/rust/naga/src/back/spv/block.rs +++ b/third_party/rust/naga/src/back/spv/block.rs @@ -731,12 +731,41 @@ impl<'w> BlockContext<'w> { Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax, other => unimplemented!("Unexpected max({:?})", other), }), - Mf::Clamp => MathOp::Ext(match arg_scalar_kind { - Some(crate::ScalarKind::Float) => spirv::GLOp::FClamp, - Some(crate::ScalarKind::Sint) => spirv::GLOp::SClamp, - Some(crate::ScalarKind::Uint) => spirv::GLOp::UClamp, + Mf::Clamp => match arg_scalar_kind { + // Clamp is undefined if min > max. In practice this means it can use a median-of-three + // instruction to determine the value. This is fine according to the WGSL spec for float + // clamp, but integer clamp _must_ use min-max. As such we write out min/max. + Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FClamp), + Some(_) => { + let (min_op, max_op) = match arg_scalar_kind { + Some(crate::ScalarKind::Sint) => { + (spirv::GLOp::SMin, spirv::GLOp::SMax) + } + Some(crate::ScalarKind::Uint) => { + (spirv::GLOp::UMin, spirv::GLOp::UMax) + } + _ => unreachable!(), + }; + + let max_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + max_op, + result_type_id, + max_id, + &[arg0_id, arg1_id], + )); + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + min_op, + result_type_id, + id, + &[max_id, arg2_id], + )) + } other => unimplemented!("Unexpected max({:?})", other), - }), + }, Mf::Saturate => { let (maybe_size, scalar) = match *arg_ty { crate::TypeInner::Vector { size, scalar } => (Some(size), scalar), @@ -915,8 +944,7 @@ impl<'w> BlockContext<'w> { )), Mf::CountTrailingZeros => { let uint_id = match *arg_ty { - crate::TypeInner::Vector { size, mut scalar } => { - scalar.kind = crate::ScalarKind::Uint; + crate::TypeInner::Vector { size, scalar } => { let ty = LocalType::Value { vector_size: Some(size), scalar, @@ -927,15 +955,15 @@ impl<'w> BlockContext<'w> { self.temp_list.clear(); self.temp_list.resize( size as _, - self.writer.get_constant_scalar_with(32, scalar)?, + self.writer + .get_constant_scalar_with(scalar.width * 8, scalar)?, ); self.writer.get_constant_composite(ty, &self.temp_list) } - crate::TypeInner::Scalar(mut scalar) => { - scalar.kind = crate::ScalarKind::Uint; - self.writer.get_constant_scalar_with(32, scalar)? - } + crate::TypeInner::Scalar(scalar) => self + .writer + .get_constant_scalar_with(scalar.width * 8, scalar)?, _ => unreachable!(), }; @@ -957,9 +985,8 @@ impl<'w> BlockContext<'w> { )) } Mf::CountLeadingZeros => { - let (int_type_id, int_id) = match *arg_ty { - crate::TypeInner::Vector { size, mut scalar } => { - scalar.kind = crate::ScalarKind::Sint; + let (int_type_id, int_id, width) = match *arg_ty { + crate::TypeInner::Vector { size, scalar } => { let ty = LocalType::Value { vector_size: Some(size), scalar, @@ -970,32 +997,41 @@ impl<'w> BlockContext<'w> { self.temp_list.clear(); self.temp_list.resize( size as _, - self.writer.get_constant_scalar_with(31, scalar)?, + self.writer + .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?, ); ( self.get_type_id(ty), self.writer.get_constant_composite(ty, &self.temp_list), + scalar.width, ) } - crate::TypeInner::Scalar(mut scalar) => { - scalar.kind = crate::ScalarKind::Sint; - ( - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar, - pointer_space: None, - })), - self.writer.get_constant_scalar_with(31, scalar)?, - ) - } + crate::TypeInner::Scalar(scalar) => ( + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + })), + self.writer + .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?, + scalar.width, + ), _ => unreachable!(), }; + if width != 4 { + unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276"); + }; + let msb_id = self.gen_id(); block.body.push(Instruction::ext_inst( self.writer.gl450_ext_inst_id, - spirv::GLOp::FindUMsb, + if width != 4 { + spirv::GLOp::FindILsb + } else { + spirv::GLOp::FindUMsb + }, int_type_id, msb_id, &[arg0_id], @@ -1021,30 +1057,144 @@ impl<'w> BlockContext<'w> { Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract, other => unimplemented!("Unexpected sign({:?})", other), }; + + // The behavior of ExtractBits is undefined when offset + count > bit_width. We need + // to first sanitize the offset and count first. If we don't do this, AMD and Intel + // will return out-of-spec values if the extracted range is not within the bit width. + // + // This encodes the exact formula specified by the wgsl spec: + // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin + // + // w = sizeof(x) * 8 + // o = min(offset, w) + // tmp = w - o + // c = min(count, tmp) + // + // bitfieldExtract(x, o, c) + + let bit_width = arg_ty.scalar_width().unwrap(); + let width_constant = self + .writer + .get_constant_scalar(crate::Literal::U32(bit_width as u32)); + + let u32_type = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + pointer_space: None, + })); + + // o = min(offset, w) + let offset_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + u32_type, + offset_id, + &[arg1_id, width_constant], + )); + + // tmp = w - o + let max_count_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + u32_type, + max_count_id, + width_constant, + offset_id, + )); + + // c = min(count, tmp) + let count_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + u32_type, + count_id, + &[arg2_id, max_count_id], + )); + MathOp::Custom(Instruction::ternary( op, result_type_id, id, arg0_id, + offset_id, + count_id, + )) + } + Mf::InsertBits => { + // The behavior of InsertBits has the same undefined behavior as ExtractBits. + + let bit_width = arg_ty.scalar_width().unwrap(); + let width_constant = self + .writer + .get_constant_scalar(crate::Literal::U32(bit_width as u32)); + + let u32_type = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + pointer_space: None, + })); + + // o = min(offset, w) + let offset_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + u32_type, + offset_id, + &[arg2_id, width_constant], + )); + + // tmp = w - o + let max_count_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + u32_type, + max_count_id, + width_constant, + offset_id, + )); + + // c = min(count, tmp) + let count_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + u32_type, + count_id, + &[arg3_id, max_count_id], + )); + + MathOp::Custom(Instruction::quaternary( + spirv::Op::BitFieldInsert, + result_type_id, + id, + arg0_id, arg1_id, - arg2_id, + offset_id, + count_id, )) } - Mf::InsertBits => MathOp::Custom(Instruction::quaternary( - spirv::Op::BitFieldInsert, - result_type_id, - id, - arg0_id, - arg1_id, - arg2_id, - arg3_id, - )), Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb), - Mf::FindMsb => MathOp::Ext(match arg_scalar_kind { - Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb, - Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb, - other => unimplemented!("Unexpected findMSB({:?})", other), - }), + Mf::FindMsb => { + if arg_ty.scalar_width() == Some(32) { + let thing = match arg_scalar_kind { + Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb, + Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb, + other => unimplemented!("Unexpected findMSB({:?})", other), + }; + MathOp::Ext(thing) + } else { + unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276"); + } + } Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8), Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8), Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16), @@ -1250,6 +1400,12 @@ impl<'w> BlockContext<'w> { (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => { Cast::Unary(spirv::Op::UConvert) } + (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => { + Cast::Unary(spirv::Op::SConvert) + } + (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => { + Cast::Unary(spirv::Op::UConvert) + } // We assume it's either an identity cast, or int-uint. _ => Cast::Unary(spirv::Op::Bitcast), } diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs index 4db86c93a7..de3220bbda 100644 --- a/third_party/rust/naga/src/back/spv/writer.rs +++ b/third_party/rust/naga/src/back/spv/writer.rs @@ -1182,6 +1182,9 @@ impl Writer { crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()), crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value), crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32), + crate::Literal::U64(value) => { + Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32) + } crate::Literal::I64(value) => { Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32) } diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs index c737934f5e..3039cbbbe4 100644 --- a/third_party/rust/naga/src/back/wgsl/writer.rs +++ b/third_party/rust/naga/src/back/wgsl/writer.rs @@ -109,7 +109,7 @@ impl<W: Write> Writer<W> { self.reset(module); // Save all ep result types - for (_, ep) in module.entry_points.iter().enumerate() { + for ep in &module.entry_points { if let Some(ref result) = ep.function.result { self.ep_results.push((ep.stage, result.ty)); } @@ -593,6 +593,7 @@ impl<W: Write> Writer<W> { } write!(self.out, ">")?; } + TypeInner::AccelerationStructure => write!(self.out, "acceleration_structure")?, _ => { return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))); } @@ -1095,16 +1096,24 @@ impl<W: Write> Writer<W> { // value can only be expressed in WGSL using AbstractInt and // a unary negation operator. if value == i32::MIN { - write!(self.out, "i32(-2147483648)")?; + write!(self.out, "i32({})", value)?; } else { write!(self.out, "{}i", value)?; } } crate::Literal::Bool(value) => write!(self.out, "{}", value)?, crate::Literal::F64(value) => write!(self.out, "{:?}lf", value)?, - crate::Literal::I64(_) => { - return Err(Error::Custom("unsupported i64 literal".to_string())); + crate::Literal::I64(value) => { + // `-9223372036854775808li` is not valid WGSL. The most negative `i64` + // value can only be expressed in WGSL using AbstractInt and + // a unary negation operator. + if value == i64::MIN { + write!(self.out, "i64({})", value)?; + } else { + write!(self.out, "{}li", value)?; + } } + crate::Literal::U64(value) => write!(self.out, "{:?}lu", value)?, crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { return Err(Error::Custom( "Abstract types should not appear in IR presented to backends".into(), @@ -1828,6 +1837,14 @@ const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str { width: 4, } => "u32", Scalar { + kind: Sk::Sint, + width: 8, + } => "i64", + Scalar { + kind: Sk::Uint, + width: 8, + } => "u64", + Scalar { kind: Sk::Bool, width: 1, } => "bool", diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs index df8cc8a30e..01846eb814 100644 --- a/third_party/rust/naga/src/front/glsl/functions.rs +++ b/third_party/rust/naga/src/front/glsl/functions.rs @@ -160,7 +160,7 @@ impl Frontend { } => self.matrix_one_arg(ctx, ty, columns, rows, scalar, (value, expr_meta), meta)?, TypeInner::Struct { ref members, .. } => { let scalar_components = members - .get(0) + .first() .and_then(|member| scalar_components(&ctx.module.types[member.ty].inner)); if let Some(scalar) = scalar_components { ctx.implicit_conversion(&mut value, expr_meta, scalar)?; diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs index 38184eedf7..d428d74761 100644 --- a/third_party/rust/naga/src/front/glsl/parser/functions.rs +++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs @@ -435,7 +435,7 @@ impl<'source> ParsingContext<'source> { if self.bump_if(frontend, TokenValue::Semicolon).is_none() { if self.peek_type_name(frontend) || self.peek_type_qualifier(frontend) { - self.parse_declaration(frontend, ctx, false, false)?; + self.parse_declaration(frontend, ctx, false, is_inside_loop)?; } else { let mut stmt = ctx.stmt_ctx(); let expr = self.parse_expression(frontend, ctx, &mut stmt)?; diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs index 198d9c52dd..e81ecf5c9b 100644 --- a/third_party/rust/naga/src/front/spv/function.rs +++ b/third_party/rust/naga/src/front/spv/function.rs @@ -292,278 +292,286 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> { ); if let Some(ep) = self.lookup_entry_point.remove(&fun_id) { - // create a wrapping function - let mut function = crate::Function { - name: Some(format!("{}_wrap", ep.name)), - arguments: Vec::new(), - result: None, - local_variables: Arena::new(), - expressions: Arena::new(), - named_expressions: crate::NamedExpressions::default(), - body: crate::Block::new(), - }; + self.deferred_entry_points.push((ep, fun_id)); + } - // 1. copy the inputs from arguments to privates - for &v_id in ep.variable_ids.iter() { - let lvar = self.lookup_variable.lookup(v_id)?; - if let super::Variable::Input(ref arg) = lvar.inner { - let span = module.global_variables.get_span(lvar.handle); - let arg_expr = function.expressions.append( - crate::Expression::FunctionArgument(function.arguments.len() as u32), - span, - ); - let load_expr = if arg.ty == module.global_variables[lvar.handle].ty { - arg_expr - } else { - // The only case where the type is different is if we need to treat - // unsigned integer as signed. - let mut emitter = Emitter::default(); - emitter.start(&function.expressions); - let handle = function.expressions.append( - crate::Expression::As { - expr: arg_expr, - kind: crate::ScalarKind::Sint, - convert: Some(4), - }, - span, - ); - function.body.extend(emitter.finish(&function.expressions)); - handle - }; - function.body.push( - crate::Statement::Store { - pointer: function - .expressions - .append(crate::Expression::GlobalVariable(lvar.handle), span), - value: load_expr, + Ok(()) + } + + pub(super) fn process_entry_point( + &mut self, + module: &mut crate::Module, + ep: super::EntryPoint, + fun_id: u32, + ) -> Result<(), Error> { + // create a wrapping function + let mut function = crate::Function { + name: Some(format!("{}_wrap", ep.name)), + arguments: Vec::new(), + result: None, + local_variables: Arena::new(), + expressions: Arena::new(), + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::new(), + }; + + // 1. copy the inputs from arguments to privates + for &v_id in ep.variable_ids.iter() { + let lvar = self.lookup_variable.lookup(v_id)?; + if let super::Variable::Input(ref arg) = lvar.inner { + let span = module.global_variables.get_span(lvar.handle); + let arg_expr = function.expressions.append( + crate::Expression::FunctionArgument(function.arguments.len() as u32), + span, + ); + let load_expr = if arg.ty == module.global_variables[lvar.handle].ty { + arg_expr + } else { + // The only case where the type is different is if we need to treat + // unsigned integer as signed. + let mut emitter = Emitter::default(); + emitter.start(&function.expressions); + let handle = function.expressions.append( + crate::Expression::As { + expr: arg_expr, + kind: crate::ScalarKind::Sint, + convert: Some(4), }, span, ); + function.body.extend(emitter.finish(&function.expressions)); + handle + }; + function.body.push( + crate::Statement::Store { + pointer: function + .expressions + .append(crate::Expression::GlobalVariable(lvar.handle), span), + value: load_expr, + }, + span, + ); - let mut arg = arg.clone(); - if ep.stage == crate::ShaderStage::Fragment { - if let Some(ref mut binding) = arg.binding { - binding.apply_default_interpolation(&module.types[arg.ty].inner); - } + let mut arg = arg.clone(); + if ep.stage == crate::ShaderStage::Fragment { + if let Some(ref mut binding) = arg.binding { + binding.apply_default_interpolation(&module.types[arg.ty].inner); } - function.arguments.push(arg); } + function.arguments.push(arg); } - // 2. call the wrapped function - let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision - let dummy_handle = self.add_call(fake_id, fun_id); - function.body.push( - crate::Statement::Call { - function: dummy_handle, - arguments: Vec::new(), - result: None, - }, - crate::Span::default(), - ); - - // 3. copy the outputs from privates to the result - let mut members = Vec::new(); - let mut components = Vec::new(); - for &v_id in ep.variable_ids.iter() { - let lvar = self.lookup_variable.lookup(v_id)?; - if let super::Variable::Output(ref result) = lvar.inner { - let span = module.global_variables.get_span(lvar.handle); - let expr_handle = function - .expressions - .append(crate::Expression::GlobalVariable(lvar.handle), span); + } + // 2. call the wrapped function + let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision + let dummy_handle = self.add_call(fake_id, fun_id); + function.body.push( + crate::Statement::Call { + function: dummy_handle, + arguments: Vec::new(), + result: None, + }, + crate::Span::default(), + ); - // Cull problematic builtins of gl_PerVertex. - // See the docs for `Frontend::gl_per_vertex_builtin_access`. + // 3. copy the outputs from privates to the result + let mut members = Vec::new(); + let mut components = Vec::new(); + for &v_id in ep.variable_ids.iter() { + let lvar = self.lookup_variable.lookup(v_id)?; + if let super::Variable::Output(ref result) = lvar.inner { + let span = module.global_variables.get_span(lvar.handle); + let expr_handle = function + .expressions + .append(crate::Expression::GlobalVariable(lvar.handle), span); + + // Cull problematic builtins of gl_PerVertex. + // See the docs for `Frontend::gl_per_vertex_builtin_access`. + { + let ty = &module.types[result.ty]; + if let crate::TypeInner::Struct { + members: ref original_members, + span, + } = ty.inner { - let ty = &module.types[result.ty]; - match ty.inner { - crate::TypeInner::Struct { - members: ref original_members, - span, - } if ty.name.as_deref() == Some("gl_PerVertex") => { - let mut new_members = original_members.clone(); - for member in &mut new_members { - if let Some(crate::Binding::BuiltIn(built_in)) = member.binding - { - if !self.gl_per_vertex_builtin_access.contains(&built_in) { - member.binding = None - } - } - } - if &new_members != original_members { - module.types.replace( - result.ty, - crate::Type { - name: ty.name.clone(), - inner: crate::TypeInner::Struct { - members: new_members, - span, - }, - }, - ); + let mut new_members = None; + for (idx, member) in original_members.iter().enumerate() { + if let Some(crate::Binding::BuiltIn(built_in)) = member.binding { + if !self.gl_per_vertex_builtin_access.contains(&built_in) { + new_members.get_or_insert_with(|| original_members.clone()) + [idx] + .binding = None; } } - _ => {} + } + if let Some(new_members) = new_members { + module.types.replace( + result.ty, + crate::Type { + name: ty.name.clone(), + inner: crate::TypeInner::Struct { + members: new_members, + span, + }, + }, + ); } } + } - match module.types[result.ty].inner { - crate::TypeInner::Struct { - members: ref sub_members, - .. - } => { - for (index, sm) in sub_members.iter().enumerate() { - if sm.binding.is_none() { - continue; - } - let mut sm = sm.clone(); - - if let Some(ref mut binding) = sm.binding { - if ep.stage == crate::ShaderStage::Vertex { - binding.apply_default_interpolation( - &module.types[sm.ty].inner, - ); - } - } - - members.push(sm); - - components.push(function.expressions.append( - crate::Expression::AccessIndex { - base: expr_handle, - index: index as u32, - }, - span, - )); + match module.types[result.ty].inner { + crate::TypeInner::Struct { + members: ref sub_members, + .. + } => { + for (index, sm) in sub_members.iter().enumerate() { + if sm.binding.is_none() { + continue; } - } - ref inner => { - let mut binding = result.binding.clone(); - if let Some(ref mut binding) = binding { + let mut sm = sm.clone(); + + if let Some(ref mut binding) = sm.binding { if ep.stage == crate::ShaderStage::Vertex { - binding.apply_default_interpolation(inner); + binding.apply_default_interpolation(&module.types[sm.ty].inner); } } - members.push(crate::StructMember { - name: None, - ty: result.ty, - binding, - offset: 0, - }); - // populate just the globals first, then do `Load` in a - // separate step, so that we can get a range. - components.push(expr_handle); + members.push(sm); + + components.push(function.expressions.append( + crate::Expression::AccessIndex { + base: expr_handle, + index: index as u32, + }, + span, + )); } } - } - } + ref inner => { + let mut binding = result.binding.clone(); + if let Some(ref mut binding) = binding { + if ep.stage == crate::ShaderStage::Vertex { + binding.apply_default_interpolation(inner); + } + } - for (member_index, member) in members.iter().enumerate() { - match member.binding { - Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. })) - if self.options.adjust_coordinate_space => - { - let mut emitter = Emitter::default(); - emitter.start(&function.expressions); - let global_expr = components[member_index]; - let span = function.expressions.get_span(global_expr); - let access_expr = function.expressions.append( - crate::Expression::AccessIndex { - base: global_expr, - index: 1, - }, - span, - ); - let load_expr = function.expressions.append( - crate::Expression::Load { - pointer: access_expr, - }, - span, - ); - let neg_expr = function.expressions.append( - crate::Expression::Unary { - op: crate::UnaryOperator::Negate, - expr: load_expr, - }, - span, - ); - function.body.extend(emitter.finish(&function.expressions)); - function.body.push( - crate::Statement::Store { - pointer: access_expr, - value: neg_expr, - }, - span, - ); + members.push(crate::StructMember { + name: None, + ty: result.ty, + binding, + offset: 0, + }); + // populate just the globals first, then do `Load` in a + // separate step, so that we can get a range. + components.push(expr_handle); } - _ => {} } } + } - let mut emitter = Emitter::default(); - emitter.start(&function.expressions); - for component in components.iter_mut() { - let load_expr = crate::Expression::Load { - pointer: *component, - }; - let span = function.expressions.get_span(*component); - *component = function.expressions.append(load_expr, span); - } - - match members[..] { - [] => {} - [ref member] => { - function.body.extend(emitter.finish(&function.expressions)); - let span = function.expressions.get_span(components[0]); - function.body.push( - crate::Statement::Return { - value: components.first().cloned(), + for (member_index, member) in members.iter().enumerate() { + match member.binding { + Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. })) + if self.options.adjust_coordinate_space => + { + let mut emitter = Emitter::default(); + emitter.start(&function.expressions); + let global_expr = components[member_index]; + let span = function.expressions.get_span(global_expr); + let access_expr = function.expressions.append( + crate::Expression::AccessIndex { + base: global_expr, + index: 1, }, span, ); - function.result = Some(crate::FunctionResult { - ty: member.ty, - binding: member.binding.clone(), - }); - } - _ => { - let span = crate::Span::total_span( - components.iter().map(|h| function.expressions.get_span(*h)), + let load_expr = function.expressions.append( + crate::Expression::Load { + pointer: access_expr, + }, + span, ); - let ty = module.types.insert( - crate::Type { - name: None, - inner: crate::TypeInner::Struct { - members, - span: 0xFFFF, // shouldn't matter - }, + let neg_expr = function.expressions.append( + crate::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr: load_expr, }, span, ); - let result_expr = function - .expressions - .append(crate::Expression::Compose { ty, components }, span); function.body.extend(emitter.finish(&function.expressions)); function.body.push( - crate::Statement::Return { - value: Some(result_expr), + crate::Statement::Store { + pointer: access_expr, + value: neg_expr, }, span, ); - function.result = Some(crate::FunctionResult { ty, binding: None }); } + _ => {} } + } - module.entry_points.push(crate::EntryPoint { - name: ep.name, - stage: ep.stage, - early_depth_test: ep.early_depth_test, - workgroup_size: ep.workgroup_size, - function, - }); + let mut emitter = Emitter::default(); + emitter.start(&function.expressions); + for component in components.iter_mut() { + let load_expr = crate::Expression::Load { + pointer: *component, + }; + let span = function.expressions.get_span(*component); + *component = function.expressions.append(load_expr, span); } + match members[..] { + [] => {} + [ref member] => { + function.body.extend(emitter.finish(&function.expressions)); + let span = function.expressions.get_span(components[0]); + function.body.push( + crate::Statement::Return { + value: components.first().cloned(), + }, + span, + ); + function.result = Some(crate::FunctionResult { + ty: member.ty, + binding: member.binding.clone(), + }); + } + _ => { + let span = crate::Span::total_span( + components.iter().map(|h| function.expressions.get_span(*h)), + ); + let ty = module.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Struct { + members, + span: 0xFFFF, // shouldn't matter + }, + }, + span, + ); + let result_expr = function + .expressions + .append(crate::Expression::Compose { ty, components }, span); + function.body.extend(emitter.finish(&function.expressions)); + function.body.push( + crate::Statement::Return { + value: Some(result_expr), + }, + span, + ); + function.result = Some(crate::FunctionResult { ty, binding: None }); + } + } + + module.entry_points.push(crate::EntryPoint { + name: ep.name, + stage: ep.stage, + early_depth_test: ep.early_depth_test, + workgroup_size: ep.workgroup_size, + function, + }); + Ok(()) } } diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs index 8b1c854358..b793448597 100644 --- a/third_party/rust/naga/src/front/spv/mod.rs +++ b/third_party/rust/naga/src/front/spv/mod.rs @@ -577,6 +577,9 @@ pub struct Frontend<I> { lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>, lookup_function: FastHashMap<spirv::Word, LookupFunction>, lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>, + // When parsing functions, each entry point function gets an entry here so that additional + // processing for them can be performed after all function parsing. + deferred_entry_points: Vec<(EntryPoint, spirv::Word)>, //Note: each `OpFunctionCall` gets a single entry here, indexed by the // dummy `Handle<crate::Function>` of the call site. deferred_function_calls: Vec<spirv::Word>, @@ -628,6 +631,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { lookup_function_type: FastHashMap::default(), lookup_function: FastHashMap::default(), lookup_entry_point: FastHashMap::default(), + deferred_entry_points: Vec::default(), deferred_function_calls: Vec::default(), dummy_functions: Arena::new(), function_call_graph: GraphMap::new(), @@ -1561,12 +1565,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> { span, ); - if ty.name.as_deref() == Some("gl_PerVertex") { - if let Some(crate::Binding::BuiltIn(built_in)) = - members[index as usize].binding - { - self.gl_per_vertex_builtin_access.insert(built_in); - } + if let Some(crate::Binding::BuiltIn(built_in)) = + members[index as usize].binding + { + self.gl_per_vertex_builtin_access.insert(built_in); } AccessExpression { @@ -3956,6 +3958,12 @@ impl<I: Iterator<Item = u32>> Frontend<I> { }?; } + // Do entry point specific processing after all functions are parsed so that we can + // cull unused problematic builtins of gl_PerVertex. + for (ep, fun_id) in core::mem::take(&mut self.deferred_entry_points) { + self.process_entry_point(&mut module, ep, fun_id)?; + } + log::info!("Patching..."); { let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None) @@ -4868,6 +4876,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let low = self.next()?; match width { 4 => crate::Literal::U32(low), + 8 => { + inst.expect(5)?; + let high = self.next()?; + crate::Literal::U64(u64::from(high) << 32 | u64::from(low)) + } _ => return Err(Error::InvalidTypeWidth(width as u32)), } } @@ -5081,7 +5094,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { None }; let span = self.span_from_with_op(start); - let mut dec = self.future_decor.remove(&id).unwrap_or_default(); + let dec = self.future_decor.remove(&id).unwrap_or_default(); let original_ty = self.lookup_type.lookup(type_id)?.handle; let mut ty = original_ty; @@ -5127,17 +5140,6 @@ impl<I: Iterator<Item = u32>> Frontend<I> { None => map_storage_class(storage_class)?, }; - // Fix empty name for gl_PerVertex struct generated by glslang - if let crate::TypeInner::Pointer { .. } = module.types[original_ty].inner { - if ext_class == ExtendedClass::Input || ext_class == ExtendedClass::Output { - if let Some(ref dec_name) = dec.name { - if dec_name.is_empty() { - dec.name = Some("perVertexStruct".to_string()) - } - } - } - } - let (inner, var) = match ext_class { ExtendedClass::Global(mut space) => { if let crate::AddressSpace::Storage { ref mut access } = space { @@ -5323,6 +5325,21 @@ pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, E Frontend::new(words, options).parse() } +/// Helper function to check if `child` is in the scope of `parent` +fn is_parent(mut child: usize, parent: usize, block_ctx: &BlockContext) -> bool { + loop { + if child == parent { + // The child is in the scope parent + break true; + } else if child == 0 { + // Searched finished at the root the child isn't in the parent's body + break false; + } + + child = block_ctx.bodies[child].parent; + } +} + #[cfg(test)] mod test { #[test] @@ -5339,18 +5356,3 @@ mod test { let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap(); } } - -/// Helper function to check if `child` is in the scope of `parent` -fn is_parent(mut child: usize, parent: usize, block_ctx: &BlockContext) -> bool { - loop { - if child == parent { - // The child is in the scope parent - break true; - } else if child == 0 { - // Searched finished at the root the child isn't in the parent's body - break false; - } - - child = block_ctx.bodies[child].parent; - } -} diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs index 07e68f8dd9..54aa8296b1 100644 --- a/third_party/rust/naga/src/front/wgsl/error.rs +++ b/third_party/rust/naga/src/front/wgsl/error.rs @@ -87,7 +87,7 @@ impl ParseError { /// Returns a [`SourceLocation`] for the first label in the error message. pub fn location(&self, source: &str) -> Option<SourceLocation> { - self.labels.get(0).map(|label| label.0.location(source)) + self.labels.first().map(|label| label.0.location(source)) } } diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs index ba9b49e135..2ca6c182b7 100644 --- a/third_party/rust/naga/src/front/wgsl/lower/mod.rs +++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs @@ -1530,6 +1530,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f), ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), + ast::Literal::Number(Number::I64(i)) => crate::Literal::I64(i), + ast::Literal::Number(Number::U64(u)) => crate::Literal::U64(u), ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f), ast::Literal::Number(Number::AbstractInt(i)) => crate::Literal::AbstractInt(i), ast::Literal::Number(Number::AbstractFloat(f)) => { diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs index 08f1e39285..1a4911a3bd 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/conv.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs @@ -124,6 +124,14 @@ pub fn get_scalar_type(word: &str) -> Option<Scalar> { kind: Sk::Uint, width: 4, }), + "i64" => Some(Scalar { + kind: Sk::Sint, + width: 8, + }), + "u64" => Some(Scalar { + kind: Sk::Uint, + width: 8, + }), "bool" => Some(Scalar { kind: Sk::Bool, width: crate::BOOL_WIDTH, diff --git a/third_party/rust/naga/src/front/wgsl/parse/number.rs b/third_party/rust/naga/src/front/wgsl/parse/number.rs index 7b09ac59bb..ceb2cb336c 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/number.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/number.rs @@ -12,6 +12,10 @@ pub enum Number { I32(i32), /// Concrete u32 U32(u32), + /// Concrete i64 + I64(i64), + /// Concrete u64 + U64(u64), /// Concrete f32 F32(f32), /// Concrete f64 @@ -31,6 +35,8 @@ enum Kind { enum IntKind { I32, U32, + I64, + U64, } #[derive(Debug)] @@ -270,6 +276,8 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) { let kind = consume_map!(bytes, [ b'i' => Kind::Int(IntKind::I32), b'u' => Kind::Int(IntKind::U32), + b'l', b'i' => Kind::Int(IntKind::I64), + b'l', b'u' => Kind::Int(IntKind::U64), b'h' => Kind::Float(FloatKind::F16), b'f' => Kind::Float(FloatKind::F32), b'l', b'f' => Kind::Float(FloatKind::F64), @@ -416,5 +424,13 @@ fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, N Ok(num) => Ok(Number::U32(num)), Err(e) => Err(map_err(e)), }, + Some(IntKind::I64) => match i64::from_str_radix(input, radix) { + Ok(num) => Ok(Number::I64(num)), + Err(e) => Err(map_err(e)), + }, + Some(IntKind::U64) => match u64::from_str_radix(input, radix) { + Ok(num) => Ok(Number::U64(num)), + Err(e) => Err(map_err(e)), + }, } } diff --git a/third_party/rust/naga/src/front/wgsl/tests.rs b/third_party/rust/naga/src/front/wgsl/tests.rs index eb2f8a2eb3..cc3d858317 100644 --- a/third_party/rust/naga/src/front/wgsl/tests.rs +++ b/third_party/rust/naga/src/front/wgsl/tests.rs @@ -17,6 +17,7 @@ fn parse_comment() { #[test] fn parse_types() { parse_str("const a : i32 = 2;").unwrap(); + parse_str("const a : u64 = 2lu;").unwrap(); assert!(parse_str("const a : x32 = 2;").is_err()); parse_str("var t: texture_2d<f32>;").unwrap(); parse_str("var t: texture_cube_array<i32>;").unwrap(); diff --git a/third_party/rust/naga/src/keywords/wgsl.rs b/third_party/rust/naga/src/keywords/wgsl.rs index 7b47a13128..683840dc1f 100644 --- a/third_party/rust/naga/src/keywords/wgsl.rs +++ b/third_party/rust/naga/src/keywords/wgsl.rs @@ -14,6 +14,7 @@ pub const RESERVED: &[&str] = &[ "f32", "f16", "i32", + "i64", "mat2x2", "mat2x3", "mat2x4", @@ -43,6 +44,7 @@ pub const RESERVED: &[&str] = &[ "texture_depth_cube_array", "texture_depth_multisampled_2d", "u32", + "u64", "vec2", "vec3", "vec4", diff --git a/third_party/rust/naga/src/lib.rs b/third_party/rust/naga/src/lib.rs index d6b9c6a7f4..4b45174300 100644 --- a/third_party/rust/naga/src/lib.rs +++ b/third_party/rust/naga/src/lib.rs @@ -252,7 +252,8 @@ An override expression can be evaluated at pipeline creation time. clippy::collapsible_if, clippy::derive_partial_eq_without_eq, clippy::needless_borrowed_reference, - clippy::single_match + clippy::single_match, + clippy::enum_variant_names )] #![warn( trivial_casts, @@ -490,7 +491,7 @@ pub enum ScalarKind { } /// Characteristics of a scalar type. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -884,6 +885,7 @@ pub enum Literal { F32(f32), U32(u32), I32(i32), + U64(u64), I64(i64), Bool(bool), AbstractInt(i64), @@ -1255,15 +1257,18 @@ pub enum SampleLevel { #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ImageQuery { /// Get the size at the specified level. + /// + /// The return value is a `u32` for 1D images, and a `vecN<u32>` + /// for an image with dimensions N > 2. Size { /// If `None`, the base level is considered. level: Option<Handle<Expression>>, }, - /// Get the number of mipmap levels. + /// Get the number of mipmap levels, a `u32`. NumLevels, - /// Get the number of array layers. + /// Get the number of array layers, a `u32`. NumLayers, - /// Get the number of samples. + /// Get the number of samples, a `u32`. NumSamples, } @@ -1683,6 +1688,10 @@ pub enum Statement { /// A block containing more statements, to be executed sequentially. Block(Block), /// Conditionally executes one of two blocks, based on the value of the condition. + /// + /// Naga IR does not have "phi" instructions. If you need to use + /// values computed in an `accept` or `reject` block after the `If`, + /// store them in a [`LocalVariable`]. If { condition: Handle<Expression>, //bool accept: Block, @@ -1702,6 +1711,10 @@ pub enum Statement { /// represented in the IR as a series of fallthrough cases with empty /// bodies, except for the last. /// + /// Naga IR does not have "phi" instructions. If you need to use + /// values computed in a [`SwitchCase::body`] block after the `Switch`, + /// store them in a [`LocalVariable`]. + /// /// [`value`]: SwitchCase::value /// [`body`]: SwitchCase::body /// [`Default`]: SwitchValue::Default @@ -1736,6 +1749,10 @@ pub enum Statement { /// if" statement in WGSL, or a loop whose back edge is an /// `OpBranchConditional` instruction in SPIR-V. /// + /// Naga IR does not have "phi" instructions. If you need to use + /// values computed in a `body` or `continuing` block after the + /// `Loop`, store them in a [`LocalVariable`]. + /// /// [`Break`]: Statement::Break /// [`Continue`]: Statement::Continue /// [`Kill`]: Statement::Kill diff --git a/third_party/rust/naga/src/proc/constant_evaluator.rs b/third_party/rust/naga/src/proc/constant_evaluator.rs index b3884b04b1..983af3718c 100644 --- a/third_party/rust/naga/src/proc/constant_evaluator.rs +++ b/third_party/rust/naga/src/proc/constant_evaluator.rs @@ -31,7 +31,7 @@ macro_rules! gen_component_wise_extractor { $( #[doc = concat!( "Maps to [`Literal::", - stringify!($mapping), + stringify!($literal), "`]", )] $mapping([$ty; N]), @@ -200,6 +200,8 @@ gen_component_wise_extractor! { AbstractInt => AbstractInt: i64, U32 => U32: u32, I32 => I32: i32, + U64 => U64: u64, + I64 => I64: i64, ], scalar_kinds: [ Float, @@ -847,6 +849,8 @@ impl<'a> ConstantEvaluator<'a> { Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])), Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])), Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz + Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])), + Scalar::U64([e]) => Ok(Scalar::U64([e])), }) } crate::MathFunction::Min => { @@ -1280,7 +1284,7 @@ impl<'a> ConstantEvaluator<'a> { Literal::U32(v) => v as i32, Literal::F32(v) => v as i32, Literal::Bool(v) => v as i32, - Literal::F64(_) | Literal::I64(_) => { + Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => { return make_error(); } Literal::AbstractInt(v) => i32::try_from_abstract(v)?, @@ -1291,18 +1295,40 @@ impl<'a> ConstantEvaluator<'a> { Literal::U32(v) => v, Literal::F32(v) => v as u32, Literal::Bool(v) => v as u32, - Literal::F64(_) | Literal::I64(_) => { + Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => { return make_error(); } Literal::AbstractInt(v) => u32::try_from_abstract(v)?, Literal::AbstractFloat(v) => u32::try_from_abstract(v)?, }), + Sc::I64 => Literal::I64(match literal { + Literal::I32(v) => v as i64, + Literal::U32(v) => v as i64, + Literal::F32(v) => v as i64, + Literal::Bool(v) => v as i64, + Literal::F64(v) => v as i64, + Literal::I64(v) => v, + Literal::U64(v) => v as i64, + Literal::AbstractInt(v) => i64::try_from_abstract(v)?, + Literal::AbstractFloat(v) => i64::try_from_abstract(v)?, + }), + Sc::U64 => Literal::U64(match literal { + Literal::I32(v) => v as u64, + Literal::U32(v) => v as u64, + Literal::F32(v) => v as u64, + Literal::Bool(v) => v as u64, + Literal::F64(v) => v as u64, + Literal::I64(v) => v as u64, + Literal::U64(v) => v, + Literal::AbstractInt(v) => u64::try_from_abstract(v)?, + Literal::AbstractFloat(v) => u64::try_from_abstract(v)?, + }), Sc::F32 => Literal::F32(match literal { Literal::I32(v) => v as f32, Literal::U32(v) => v as f32, Literal::F32(v) => v, Literal::Bool(v) => v as u32 as f32, - Literal::F64(_) | Literal::I64(_) => { + Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => { return make_error(); } Literal::AbstractInt(v) => f32::try_from_abstract(v)?, @@ -1314,7 +1340,7 @@ impl<'a> ConstantEvaluator<'a> { Literal::F32(v) => v as f64, Literal::F64(v) => v, Literal::Bool(v) => v as u32 as f64, - Literal::I64(_) => return make_error(), + Literal::I64(_) | Literal::U64(_) => return make_error(), Literal::AbstractInt(v) => f64::try_from_abstract(v)?, Literal::AbstractFloat(v) => f64::try_from_abstract(v)?, }), @@ -1325,6 +1351,7 @@ impl<'a> ConstantEvaluator<'a> { Literal::Bool(v) => v, Literal::F64(_) | Literal::I64(_) + | Literal::U64(_) | Literal::AbstractInt(_) | Literal::AbstractFloat(_) => { return make_error(); @@ -1877,6 +1904,122 @@ impl<'a> ConstantEvaluator<'a> { } } +/// Trait for conversions of abstract values to concrete types. +trait TryFromAbstract<T>: Sized { + /// Convert an abstract literal `value` to `Self`. + /// + /// Since Naga's `AbstractInt` and `AbstractFloat` exist to support + /// WGSL, we follow WGSL's conversion rules here: + /// + /// - WGSL §6.1.2. Conversion Rank says that automatic conversions + /// to integers are either lossless or an error. + /// + /// - WGSL §14.6.4 Floating Point Conversion says that conversions + /// to floating point in constant expressions and override + /// expressions are errors if the value is out of range for the + /// destination type, but rounding is okay. + /// + /// [`AbstractInt`]: crate::Literal::AbstractInt + /// [`Float`]: crate::Literal::Float + fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>; +} + +impl TryFromAbstract<i64> for i32 { + fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> { + i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "i32", + }) + } +} + +impl TryFromAbstract<i64> for u32 { + fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> { + u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "u32", + }) + } +} + +impl TryFromAbstract<i64> for u64 { + fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> { + u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "u64", + }) + } +} + +impl TryFromAbstract<i64> for i64 { + fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> { + Ok(value) + } +} + +impl TryFromAbstract<i64> for f32 { + fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> { + let f = value as f32; + // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of + // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for + // overflow here. + Ok(f) + } +} + +impl TryFromAbstract<f64> for f32 { + fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> { + let f = value as f32; + if f.is_infinite() { + return Err(ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "f32", + }); + } + Ok(f) + } +} + +impl TryFromAbstract<i64> for f64 { + fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> { + let f = value as f64; + // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of + // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for + // overflow here. + Ok(f) + } +} + +impl TryFromAbstract<f64> for f64 { + fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> { + Ok(value) + } +} + +impl TryFromAbstract<f64> for i32 { + fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" }) + } +} + +impl TryFromAbstract<f64> for u32 { + fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" }) + } +} + +impl TryFromAbstract<f64> for i64 { + fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i64" }) + } +} + +impl TryFromAbstract<f64> for u64 { + fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u64" }) + } +} + #[cfg(test)] mod tests { use std::vec; @@ -2384,92 +2527,3 @@ mod tests { } } } - -/// Trait for conversions of abstract values to concrete types. -trait TryFromAbstract<T>: Sized { - /// Convert an abstract literal `value` to `Self`. - /// - /// Since Naga's `AbstractInt` and `AbstractFloat` exist to support - /// WGSL, we follow WGSL's conversion rules here: - /// - /// - WGSL §6.1.2. Conversion Rank says that automatic conversions - /// to integers are either lossless or an error. - /// - /// - WGSL §14.6.4 Floating Point Conversion says that conversions - /// to floating point in constant expressions and override - /// expressions are errors if the value is out of range for the - /// destination type, but rounding is okay. - /// - /// [`AbstractInt`]: crate::Literal::AbstractInt - /// [`Float`]: crate::Literal::Float - fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>; -} - -impl TryFromAbstract<i64> for i32 { - fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> { - i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { - value: format!("{value:?}"), - to_type: "i32", - }) - } -} - -impl TryFromAbstract<i64> for u32 { - fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> { - u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { - value: format!("{value:?}"), - to_type: "u32", - }) - } -} - -impl TryFromAbstract<i64> for f32 { - fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> { - let f = value as f32; - // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of - // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for - // overflow here. - Ok(f) - } -} - -impl TryFromAbstract<f64> for f32 { - fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> { - let f = value as f32; - if f.is_infinite() { - return Err(ConstantEvaluatorError::AutomaticConversionLossy { - value: format!("{value:?}"), - to_type: "f32", - }); - } - Ok(f) - } -} - -impl TryFromAbstract<i64> for f64 { - fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> { - let f = value as f64; - // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of - // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for - // overflow here. - Ok(f) - } -} - -impl TryFromAbstract<f64> for f64 { - fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> { - Ok(value) - } -} - -impl TryFromAbstract<f64> for i32 { - fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> { - Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" }) - } -} - -impl TryFromAbstract<f64> for u32 { - fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> { - Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" }) - } -} diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs index b9ce80b5ea..46cbb6c3b3 100644 --- a/third_party/rust/naga/src/proc/mod.rs +++ b/third_party/rust/naga/src/proc/mod.rs @@ -102,6 +102,10 @@ impl super::Scalar { kind: crate::ScalarKind::Sint, width: 8, }; + pub const U64: Self = Self { + kind: crate::ScalarKind::Uint, + width: 8, + }; pub const BOOL: Self = Self { kind: crate::ScalarKind::Bool, width: crate::BOOL_WIDTH, @@ -156,6 +160,7 @@ impl PartialEq for crate::Literal { (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(), (Self::U32(a), Self::U32(b)) => a == b, (Self::I32(a), Self::I32(b)) => a == b, + (Self::U64(a), Self::U64(b)) => a == b, (Self::I64(a), Self::I64(b)) => a == b, (Self::Bool(a), Self::Bool(b)) => a == b, _ => false, @@ -186,10 +191,18 @@ impl std::hash::Hash for crate::Literal { hasher.write_u8(4); v.hash(hasher); } - Self::I64(v) | Self::AbstractInt(v) => { + Self::I64(v) => { hasher.write_u8(5); v.hash(hasher); } + Self::U64(v) => { + hasher.write_u8(6); + v.hash(hasher); + } + Self::AbstractInt(v) => { + hasher.write_u8(7); + v.hash(hasher); + } } } } @@ -201,6 +214,7 @@ impl crate::Literal { (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)), (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)), (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)), + (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)), (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)), (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)), (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)), @@ -218,7 +232,7 @@ impl crate::Literal { pub const fn width(&self) -> crate::Bytes { match *self { - Self::F64(_) | Self::I64(_) => 8, + Self::F64(_) | Self::I64(_) | Self::U64(_) => 8, Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, Self::Bool(_) => crate::BOOL_WIDTH, Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH, @@ -230,6 +244,7 @@ impl crate::Literal { Self::F32(_) => crate::Scalar::F32, Self::U32(_) => crate::Scalar::U32, Self::I32(_) => crate::Scalar::I32, + Self::U64(_) => crate::Scalar::U64, Self::I64(_) => crate::Scalar::I64, Self::Bool(_) => crate::Scalar::BOOL, Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT, diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs index c82d60f062..838ecc4e27 100644 --- a/third_party/rust/naga/src/valid/expression.rs +++ b/third_party/rust/naga/src/valid/expression.rs @@ -124,6 +124,8 @@ pub enum ExpressionError { MissingCapabilities(super::Capabilities), #[error(transparent)] Literal(#[from] LiteralError), + #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")] + UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes), } #[derive(Clone, Debug, thiserror::Error)] @@ -1332,28 +1334,29 @@ impl super::Validator { _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } - Mf::CountTrailingZeros - | Mf::CountLeadingZeros + // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276 + Mf::CountLeadingZeros + | Mf::CountTrailingZeros | Mf::CountOneBits | Mf::ReverseBits - | Mf::FindLsb - | Mf::FindMsb => { + | Mf::FindMsb + | Mf::FindLsb => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { - Ti::Scalar(Sc { - kind: Sk::Sint | Sk::Uint, - .. - }) - | Ti::Vector { - scalar: - Sc { - kind: Sk::Sint | Sk::Uint, - .. - }, - .. - } => {} + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Sint | Sk::Uint => { + if scalar.width != 4 { + return Err(ExpressionError::UnsupportedWidth( + fun, + scalar.kind, + scalar.width, + )); + } + } + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + }, _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } @@ -1404,6 +1407,21 @@ impl super::Validator { )) } } + // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276 + for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() { + match *arg { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { + if scalar.width != 4 { + return Err(ExpressionError::UnsupportedWidth( + fun, + scalar.kind, + scalar.width, + )); + } + } + _ => {} + } + } } Mf::ExtractBits => { let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { @@ -1445,6 +1463,21 @@ impl super::Validator { )) } } + // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276 + for &arg in [arg_ty, arg1_ty, arg2_ty].iter() { + match *arg { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { + if scalar.width != 4 { + return Err(ExpressionError::UnsupportedWidth( + fun, + scalar.kind, + scalar.width, + )); + } + } + _ => {} + } + } } Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs index 388495a3ac..5459434f33 100644 --- a/third_party/rust/naga/src/valid/mod.rs +++ b/third_party/rust/naga/src/valid/mod.rs @@ -28,7 +28,7 @@ pub use expression::{check_literal_value, LiteralError}; pub use expression::{ConstExpressionError, ExpressionError}; pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; -pub use r#type::{Disalignment, TypeError, TypeFlags}; +pub use r#type::{Disalignment, TypeError, TypeFlags, WidthError}; use self::handles::InvalidHandleError; @@ -108,6 +108,8 @@ bitflags::bitflags! { const DUAL_SOURCE_BLENDING = 0x2000; /// Support for arrayed cube textures. const CUBE_ARRAY_TEXTURES = 0x4000; + /// Support for 64-bit signed and unsigned integers. + const SHADER_INT64 = 0x8000; } } diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs index 1e3e03fe19..d44a295b1a 100644 --- a/third_party/rust/naga/src/valid/type.rs +++ b/third_party/rust/naga/src/valid/type.rs @@ -107,6 +107,12 @@ pub enum TypeError { MatrixElementNotFloat, #[error("The constant {0:?} is specialized, and cannot be used as an array size")] UnsupportedSpecializedArrayLength(Handle<crate::Constant>), + #[error("{} of dimensionality {dim:?} and class {class:?} are not supported", if *.arrayed {"Arrayed images"} else {"Images"})] + UnsupportedImageType { + dim: crate::ImageDimension, + arrayed: bool, + class: crate::ImageClass, + }, #[error("Array stride {stride} does not match the expected {expected}")] InvalidArrayStride { stride: u32, expected: u32 }, #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")] @@ -141,9 +147,6 @@ pub enum WidthError { flag: &'static str, }, - #[error("64-bit integers are not yet supported")] - Unsupported64Bit, - #[error("Abstract types may only appear in constant expressions")] Abstract, } @@ -245,11 +248,31 @@ impl super::Validator { scalar.width == 4 } } - crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + crate::ScalarKind::Sint => { + if scalar.width == 8 { + if !self.capabilities.contains(Capabilities::SHADER_INT64) { + return Err(WidthError::MissingCapability { + name: "i64", + flag: "SHADER_INT64", + }); + } + true + } else { + scalar.width == 4 + } + } + crate::ScalarKind::Uint => { if scalar.width == 8 { - return Err(WidthError::Unsupported64Bit); + if !self.capabilities.contains(Capabilities::SHADER_INT64) { + return Err(WidthError::MissingCapability { + name: "u64", + flag: "SHADER_INT64", + }); + } + true + } else { + scalar.width == 4 } - scalar.width == 4 } crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { return Err(WidthError::Abstract); @@ -596,8 +619,15 @@ impl super::Validator { Ti::Image { dim, arrayed, - class: _, + class, } => { + if arrayed && matches!(dim, crate::ImageDimension::D3) { + return Err(TypeError::UnsupportedImageType { + dim, + arrayed, + class, + }); + } if arrayed && matches!(dim, crate::ImageDimension::Cube) { self.require_type_capability(Capabilities::CUBE_ARRAY_TEXTURES)?; } |