diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/naga/src/front | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/front')
39 files changed, 29521 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/front/glsl/ast.rs b/third_party/rust/naga/src/front/glsl/ast.rs new file mode 100644 index 0000000000..96b676dd6d --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/ast.rs @@ -0,0 +1,394 @@ +use std::{borrow::Cow, fmt}; + +use super::{builtins::MacroCall, context::ExprPos, Span}; +use crate::{ + AddressSpace, BinaryOperator, Binding, Constant, Expression, Function, GlobalVariable, Handle, + Interpolation, Literal, Sampling, StorageAccess, Type, UnaryOperator, +}; + +#[derive(Debug, Clone, Copy)] +pub enum GlobalLookupKind { + Variable(Handle<GlobalVariable>), + Constant(Handle<Constant>, Handle<Type>), + BlockSelect(Handle<GlobalVariable>, u32), +} + +#[derive(Debug, Clone, Copy)] +pub struct GlobalLookup { + pub kind: GlobalLookupKind, + pub entry_arg: Option<usize>, + pub mutable: bool, +} + +#[derive(Debug, Clone)] +pub struct ParameterInfo { + pub qualifier: ParameterQualifier, + /// Whether the parameter should be treated as a depth image instead of a + /// sampled image. + pub depth: bool, +} + +/// How the function is implemented +#[derive(Clone, Copy)] +pub enum FunctionKind { + /// The function is user defined + Call(Handle<Function>), + /// The function is a builtin + Macro(MacroCall), +} + +impl fmt::Debug for FunctionKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Self::Call(_) => write!(f, "Call"), + Self::Macro(_) => write!(f, "Macro"), + } + } +} + +#[derive(Debug)] +pub struct Overload { + /// Normalized function parameters, modifiers are not applied + pub parameters: Vec<Handle<Type>>, + pub parameters_info: Vec<ParameterInfo>, + /// How the function is implemented + pub kind: FunctionKind, + /// Whether this function was already defined or is just a prototype + pub defined: bool, + /// Whether this overload is the one provided by the language or has + /// been redeclared by the user (builtins only) + pub internal: bool, + /// Whether or not this function returns void (nothing) + pub void: bool, +} + +bitflags::bitflags! { + /// Tracks the variations of the builtin already generated, this is needed because some + /// builtins overloads can't be generated unless explicitly used, since they might cause + /// unneeded capabilities to be requested + #[derive(Default)] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct BuiltinVariations: u32 { + /// Request the standard overloads + const STANDARD = 1 << 0; + /// Request overloads that use the double type + const DOUBLE = 1 << 1; + /// Request overloads that use samplerCubeArray(Shadow) + const CUBE_TEXTURES_ARRAY = 1 << 2; + /// Request overloads that use sampler2DMSArray + const D2_MULTI_TEXTURES_ARRAY = 1 << 3; + } +} + +#[derive(Debug, Default)] +pub struct FunctionDeclaration { + pub overloads: Vec<Overload>, + /// Tracks the builtin overload variations that were already generated + pub variations: BuiltinVariations, +} + +#[derive(Debug)] +pub struct EntryArg { + pub name: Option<String>, + pub binding: Binding, + pub handle: Handle<GlobalVariable>, + pub storage: StorageQualifier, +} + +#[derive(Debug, Clone)] +pub struct VariableReference { + pub expr: Handle<Expression>, + /// Whether the variable is of a pointer type (and needs loading) or not + pub load: bool, + /// Whether the value of the variable can be changed or not + pub mutable: bool, + pub constant: Option<(Handle<Constant>, Handle<Type>)>, + pub entry_arg: Option<usize>, +} + +#[derive(Debug, Clone)] +pub struct HirExpr { + pub kind: HirExprKind, + pub meta: Span, +} + +#[derive(Debug, Clone)] +pub enum HirExprKind { + Access { + base: Handle<HirExpr>, + index: Handle<HirExpr>, + }, + Select { + base: Handle<HirExpr>, + field: String, + }, + Literal(Literal), + Binary { + left: Handle<HirExpr>, + op: BinaryOperator, + right: Handle<HirExpr>, + }, + Unary { + op: UnaryOperator, + expr: Handle<HirExpr>, + }, + Variable(VariableReference), + Call(FunctionCall), + /// Represents the ternary operator in glsl (`:?`) + Conditional { + /// The expression that will decide which branch to take, must evaluate to a boolean + condition: Handle<HirExpr>, + /// The expression that will be evaluated if [`condition`] returns `true` + /// + /// [`condition`]: Self::Conditional::condition + accept: Handle<HirExpr>, + /// The expression that will be evaluated if [`condition`] returns `false` + /// + /// [`condition`]: Self::Conditional::condition + reject: Handle<HirExpr>, + }, + Assign { + tgt: Handle<HirExpr>, + value: Handle<HirExpr>, + }, + /// A prefix/postfix operator like `++` + PrePostfix { + /// The operation to be performed + op: BinaryOperator, + /// Whether this is a postfix or a prefix + postfix: bool, + /// The target expression + expr: Handle<HirExpr>, + }, + /// A method call like `what.something(a, b, c)` + Method { + /// expression the method call applies to (`what` in the example) + expr: Handle<HirExpr>, + /// the method name (`something` in the example) + name: String, + /// the arguments to the method (`a`, `b`, and `c` in the example) + args: Vec<Handle<HirExpr>>, + }, +} + +#[derive(Debug, Hash, PartialEq, Eq)] +pub enum QualifierKey<'a> { + String(Cow<'a, str>), + /// Used for `std140` and `std430` layout qualifiers + Layout, + /// Used for image formats + Format, +} + +#[derive(Debug)] +pub enum QualifierValue { + None, + Uint(u32), + Layout(StructLayout), + Format(crate::StorageFormat), +} + +#[derive(Debug, Default)] +pub struct TypeQualifiers<'a> { + pub span: Span, + pub storage: (StorageQualifier, Span), + pub invariant: Option<Span>, + pub interpolation: Option<(Interpolation, Span)>, + pub precision: Option<(Precision, Span)>, + pub sampling: Option<(Sampling, Span)>, + /// Memory qualifiers used in the declaration to set the storage access to be used + /// in declarations that support it (storage images and buffers) + pub storage_access: Option<(StorageAccess, Span)>, + pub layout_qualifiers: crate::FastHashMap<QualifierKey<'a>, (QualifierValue, Span)>, +} + +impl<'a> TypeQualifiers<'a> { + /// Appends `errors` with errors for all unused qualifiers + pub fn unused_errors(&self, errors: &mut Vec<super::Error>) { + if let Some(meta) = self.invariant { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Invariant qualifier can only be used in in/out variables".into(), + ), + meta, + }); + } + + if let Some((_, meta)) = self.interpolation { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Interpolation qualifiers can only be used in in/out variables".into(), + ), + meta, + }); + } + + if let Some((_, meta)) = self.sampling { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Sampling qualifiers can only be used in in/out variables".into(), + ), + meta, + }); + } + + if let Some((_, meta)) = self.storage_access { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Memory qualifiers can only be used in storage variables".into(), + ), + meta, + }); + } + + for &(_, meta) in self.layout_qualifiers.values() { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError("Unexpected qualifier".into()), + meta, + }); + } + } + + /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't + /// a [`QualifierValue::Uint`] + pub fn uint_layout_qualifier( + &mut self, + name: &'a str, + errors: &mut Vec<super::Error>, + ) -> Option<u32> { + match self + .layout_qualifiers + .remove(&QualifierKey::String(name.into())) + { + Some((QualifierValue::Uint(v), _)) => Some(v), + Some((_, meta)) => { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError("Qualifier expects a uint value".into()), + meta, + }); + // Return a dummy value instead of `None` to differentiate from + // the qualifier not existing, since some parts might require the + // qualifier to exist and throwing another error that it doesn't + // exist would be unhelpful + Some(0) + } + _ => None, + } + } + + /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't + /// a [`QualifierValue::None`] + pub fn none_layout_qualifier(&mut self, name: &'a str, errors: &mut Vec<super::Error>) -> bool { + match self + .layout_qualifiers + .remove(&QualifierKey::String(name.into())) + { + Some((QualifierValue::None, _)) => true, + Some((_, meta)) => { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Qualifier doesn't expect a value".into(), + ), + meta, + }); + // Return a `true` to since the qualifier is defined and adding + // another error for it not being defined would be unhelpful + true + } + _ => false, + } + } +} + +#[derive(Debug, Clone)] +pub enum FunctionCallKind { + TypeConstructor(Handle<Type>), + Function(String), +} + +#[derive(Debug, Clone)] +pub struct FunctionCall { + pub kind: FunctionCallKind, + pub args: Vec<Handle<HirExpr>>, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum StorageQualifier { + AddressSpace(AddressSpace), + Input, + Output, + Const, +} + +impl Default for StorageQualifier { + fn default() -> Self { + StorageQualifier::AddressSpace(AddressSpace::Function) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StructLayout { + Std140, + Std430, +} + +// TODO: Encode precision hints in the IR +/// A precision hint used in GLSL declarations. +/// +/// Precision hints can be used to either speed up shader execution or control +/// the precision of arithmetic operations. +/// +/// To use a precision hint simply add it before the type in the declaration. +/// ```glsl +/// mediump float a; +/// ``` +/// +/// The default when no precision is declared is `highp` which means that all +/// operations operate with the type defined width. +/// +/// For `mediump` and `lowp` operations follow the spir-v +/// [`RelaxedPrecision`][RelaxedPrecision] decoration semantics. +/// +/// [RelaxedPrecision]: https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#_a_id_relaxedprecisionsection_a_relaxed_precision +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum Precision { + /// `lowp` precision + Low, + /// `mediump` precision + Medium, + /// `highp` precision + High, +} + +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum ParameterQualifier { + In, + Out, + InOut, + Const, +} + +impl ParameterQualifier { + /// Returns true if the argument should be passed as a lhs expression + pub const fn is_lhs(&self) -> bool { + match *self { + ParameterQualifier::Out | ParameterQualifier::InOut => true, + _ => false, + } + } + + /// Converts from a parameter qualifier into a [`ExprPos`] + pub const fn as_pos(&self) -> ExprPos { + match *self { + ParameterQualifier::Out | ParameterQualifier::InOut => ExprPos::Lhs, + _ => ExprPos::Rhs, + } + } +} + +/// The GLSL profile used by a shader. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Profile { + /// The `core` profile, default when no profile is specified. + Core, +} diff --git a/third_party/rust/naga/src/front/glsl/builtins.rs b/third_party/rust/naga/src/front/glsl/builtins.rs new file mode 100644 index 0000000000..9e3a578c6b --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/builtins.rs @@ -0,0 +1,2314 @@ +use super::{ + ast::{ + BuiltinVariations, FunctionDeclaration, FunctionKind, Overload, ParameterInfo, + ParameterQualifier, + }, + context::Context, + Error, ErrorKind, Frontend, Result, +}; +use crate::{ + BinaryOperator, DerivativeAxis as Axis, DerivativeControl as Ctrl, Expression, Handle, + ImageClass, ImageDimension as Dim, ImageQuery, MathFunction, Module, RelationalFunction, + SampleLevel, Scalar, ScalarKind as Sk, Span, Type, TypeInner, UnaryOperator, VectorSize, +}; + +impl crate::ScalarKind { + const fn dummy_storage_format(&self) -> crate::StorageFormat { + match *self { + Sk::Sint => crate::StorageFormat::R16Sint, + Sk::Uint => crate::StorageFormat::R16Uint, + _ => crate::StorageFormat::R16Float, + } + } +} + +impl Module { + /// Helper function, to create a function prototype for a builtin + fn add_builtin(&mut self, args: Vec<TypeInner>, builtin: MacroCall) -> Overload { + let mut parameters = Vec::with_capacity(args.len()); + let mut parameters_info = Vec::with_capacity(args.len()); + + for arg in args { + parameters.push(self.types.insert( + Type { + name: None, + inner: arg, + }, + Span::default(), + )); + parameters_info.push(ParameterInfo { + qualifier: ParameterQualifier::In, + depth: false, + }); + } + + Overload { + parameters, + parameters_info, + kind: FunctionKind::Macro(builtin), + defined: false, + internal: true, + void: false, + } + } +} + +const fn make_coords_arg(number_of_components: usize, kind: Sk) -> TypeInner { + let scalar = Scalar { kind, width: 4 }; + + match number_of_components { + 1 => TypeInner::Scalar(scalar), + _ => TypeInner::Vector { + size: match number_of_components { + 2 => VectorSize::Bi, + 3 => VectorSize::Tri, + _ => VectorSize::Quad, + }, + scalar, + }, + } +} + +/// Inject builtins into the declaration +/// +/// This is done to not add a large startup cost and not increase memory +/// usage if it isn't needed. +pub fn inject_builtin( + declaration: &mut FunctionDeclaration, + module: &mut Module, + name: &str, + mut variations: BuiltinVariations, +) { + log::trace!( + "{} variations: {:?} {:?}", + name, + variations, + declaration.variations + ); + // Don't regeneate variations + variations.remove(declaration.variations); + declaration.variations |= variations; + + if variations.contains(BuiltinVariations::STANDARD) { + inject_standard_builtins(declaration, module, name) + } + + if variations.contains(BuiltinVariations::DOUBLE) { + inject_double_builtin(declaration, module, name) + } + + match name { + "texture" + | "textureGrad" + | "textureGradOffset" + | "textureLod" + | "textureLodOffset" + | "textureOffset" + | "textureProj" + | "textureProjGrad" + | "textureProjGradOffset" + | "textureProjLod" + | "textureProjLodOffset" + | "textureProjOffset" => { + let f = |kind, dim, arrayed, multi, shadow| { + for bits in 0..=0b11 { + let variant = bits & 0b1 != 0; + let bias = bits & 0b10 != 0; + + let (proj, offset, level_type) = match name { + // texture(gsampler, gvec P, [float bias]); + "texture" => (false, false, TextureLevelType::None), + // textureGrad(gsampler, gvec P, gvec dPdx, gvec dPdy); + "textureGrad" => (false, false, TextureLevelType::Grad), + // textureGradOffset(gsampler, gvec P, gvec dPdx, gvec dPdy, ivec offset); + "textureGradOffset" => (false, true, TextureLevelType::Grad), + // textureLod(gsampler, gvec P, float lod); + "textureLod" => (false, false, TextureLevelType::Lod), + // textureLodOffset(gsampler, gvec P, float lod, ivec offset); + "textureLodOffset" => (false, true, TextureLevelType::Lod), + // textureOffset(gsampler, gvec+1 P, ivec offset, [float bias]); + "textureOffset" => (false, true, TextureLevelType::None), + // textureProj(gsampler, gvec+1 P, [float bias]); + "textureProj" => (true, false, TextureLevelType::None), + // textureProjGrad(gsampler, gvec+1 P, gvec dPdx, gvec dPdy); + "textureProjGrad" => (true, false, TextureLevelType::Grad), + // textureProjGradOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset); + "textureProjGradOffset" => (true, true, TextureLevelType::Grad), + // textureProjLod(gsampler, gvec+1 P, float lod); + "textureProjLod" => (true, false, TextureLevelType::Lod), + // textureProjLodOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset); + "textureProjLodOffset" => (true, true, TextureLevelType::Lod), + // textureProjOffset(gsampler, gvec+1 P, ivec offset, [float bias]); + "textureProjOffset" => (true, true, TextureLevelType::None), + _ => unreachable!(), + }; + + let builtin = MacroCall::Texture { + proj, + offset, + shadow, + level_type, + }; + + // Parse out the variant settings. + let grad = level_type == TextureLevelType::Grad; + let lod = level_type == TextureLevelType::Lod; + + let supports_variant = proj && !shadow; + if variant && !supports_variant { + continue; + } + + if bias && !matches!(level_type, TextureLevelType::None) { + continue; + } + + // Proj doesn't work with arrayed or Cube + if proj && (arrayed || dim == Dim::Cube) { + continue; + } + + // texture operations with offset are not supported for cube maps + if dim == Dim::Cube && offset { + continue; + } + + // sampler2DArrayShadow can't be used in textureLod or in texture with bias + if (lod || bias) && arrayed && shadow && dim == Dim::D2 { + continue; + } + + // TODO: glsl supports using bias with depth samplers but naga doesn't + if bias && shadow { + continue; + } + + let class = match shadow { + true => ImageClass::Depth { multi }, + false => ImageClass::Sampled { kind, multi }, + }; + + let image = TypeInner::Image { + dim, + arrayed, + class, + }; + + let num_coords_from_dim = image_dims_to_coords_size(dim).min(3); + let mut num_coords = num_coords_from_dim; + + if shadow && proj { + num_coords = 4; + } else if dim == Dim::D1 && shadow { + num_coords = 3; + } else if shadow { + num_coords += 1; + } else if proj { + if variant && num_coords == 4 { + // Normal form already has 4 components, no need to have a variant form. + continue; + } else if variant { + num_coords = 4; + } else { + num_coords += 1; + } + } + + if !(dim == Dim::D1 && shadow) { + num_coords += arrayed as usize; + } + + // Special case: texture(gsamplerCubeArrayShadow) kicks the shadow compare ref to a separate argument, + // since it would otherwise take five arguments. It also can't take a bias, nor can it be proj/grad/lod/offset + // (presumably because nobody asked for it, and implementation complexity?) + if num_coords >= 5 { + if lod || grad || offset || proj || bias { + continue; + } + debug_assert!(dim == Dim::Cube && shadow && arrayed); + } + debug_assert!(num_coords <= 5); + + let vector = make_coords_arg(num_coords, Sk::Float); + let mut args = vec![image, vector]; + + if num_coords == 5 { + args.push(TypeInner::Scalar(Scalar::F32)); + } + + match level_type { + TextureLevelType::Lod => { + args.push(TypeInner::Scalar(Scalar::F32)); + } + TextureLevelType::Grad => { + args.push(make_coords_arg(num_coords_from_dim, Sk::Float)); + args.push(make_coords_arg(num_coords_from_dim, Sk::Float)); + } + TextureLevelType::None => {} + }; + + if offset { + args.push(make_coords_arg(num_coords_from_dim, Sk::Sint)); + } + + if bias { + args.push(TypeInner::Scalar(Scalar::F32)); + } + + declaration + .overloads + .push(module.add_builtin(args, builtin)); + } + }; + + texture_args_generator(TextureArgsOptions::SHADOW | variations.into(), f) + } + "textureSize" => { + let f = |kind, dim, arrayed, multi, shadow| { + let class = match shadow { + true => ImageClass::Depth { multi }, + false => ImageClass::Sampled { kind, multi }, + }; + + let image = TypeInner::Image { + dim, + arrayed, + class, + }; + + let mut args = vec![image]; + + if !multi { + args.push(TypeInner::Scalar(Scalar::I32)) + } + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::TextureSize { arrayed })) + }; + + texture_args_generator( + TextureArgsOptions::SHADOW | TextureArgsOptions::MULTI | variations.into(), + f, + ) + } + "texelFetch" | "texelFetchOffset" => { + let offset = "texelFetchOffset" == name; + let f = |kind, dim, arrayed, multi, _shadow| { + // Cube images aren't supported + if let Dim::Cube = dim { + return; + } + + let image = TypeInner::Image { + dim, + arrayed, + class: ImageClass::Sampled { kind, multi }, + }; + + let dim_value = image_dims_to_coords_size(dim); + let coordinates = make_coords_arg(dim_value + arrayed as usize, Sk::Sint); + + let mut args = vec![image, coordinates, TypeInner::Scalar(Scalar::I32)]; + + if offset { + args.push(make_coords_arg(dim_value, Sk::Sint)); + } + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::ImageLoad { multi })) + }; + + // Don't generate shadow images since they aren't supported + texture_args_generator(TextureArgsOptions::MULTI | variations.into(), f) + } + "imageSize" => { + let f = |kind: Sk, dim, arrayed, _, _| { + // Naga doesn't support cube images and it's usefulness + // is questionable, so they won't be supported for now + if dim == Dim::Cube { + return; + } + + let image = TypeInner::Image { + dim, + arrayed, + class: ImageClass::Storage { + format: kind.dummy_storage_format(), + access: crate::StorageAccess::empty(), + }, + }; + + declaration + .overloads + .push(module.add_builtin(vec![image], MacroCall::TextureSize { arrayed })) + }; + + texture_args_generator(variations.into(), f) + } + "imageLoad" => { + let f = |kind: Sk, dim, arrayed, _, _| { + // Naga doesn't support cube images and it's usefulness + // is questionable, so they won't be supported for now + if dim == Dim::Cube { + return; + } + + let image = TypeInner::Image { + dim, + arrayed, + class: ImageClass::Storage { + format: kind.dummy_storage_format(), + access: crate::StorageAccess::LOAD, + }, + }; + + let dim_value = image_dims_to_coords_size(dim); + let mut coord_size = dim_value + arrayed as usize; + // > Every OpenGL API call that operates on cubemap array + // > textures takes layer-faces, not array layers + // + // So this means that imageCubeArray only takes a three component + // vector coordinate and the third component is a layer index. + if Dim::Cube == dim && arrayed { + coord_size = 3 + } + let coordinates = make_coords_arg(coord_size, Sk::Sint); + + let args = vec![image, coordinates]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::ImageLoad { multi: false })) + }; + + // Don't generate shadow nor multisampled images since they aren't supported + texture_args_generator(variations.into(), f) + } + "imageStore" => { + let f = |kind: Sk, dim, arrayed, _, _| { + // Naga doesn't support cube images and it's usefulness + // is questionable, so they won't be supported for now + if dim == Dim::Cube { + return; + } + + let image = TypeInner::Image { + dim, + arrayed, + class: ImageClass::Storage { + format: kind.dummy_storage_format(), + access: crate::StorageAccess::STORE, + }, + }; + + let dim_value = image_dims_to_coords_size(dim); + let mut coord_size = dim_value + arrayed as usize; + // > Every OpenGL API call that operates on cubemap array + // > textures takes layer-faces, not array layers + // + // So this means that imageCubeArray only takes a three component + // vector coordinate and the third component is a layer index. + if Dim::Cube == dim && arrayed { + coord_size = 3 + } + let coordinates = make_coords_arg(coord_size, Sk::Sint); + + let args = vec![ + image, + coordinates, + TypeInner::Vector { + size: VectorSize::Quad, + scalar: Scalar { kind, width: 4 }, + }, + ]; + + let mut overload = module.add_builtin(args, MacroCall::ImageStore); + overload.void = true; + declaration.overloads.push(overload) + }; + + // Don't generate shadow nor multisampled images since they aren't supported + texture_args_generator(variations.into(), f) + } + _ => {} + } +} + +/// Injects the builtins into declaration that don't need any special variations +fn inject_standard_builtins( + declaration: &mut FunctionDeclaration, + module: &mut Module, + name: &str, +) { + match name { + "sampler1D" | "sampler1DArray" | "sampler2D" | "sampler2DArray" | "sampler2DMS" + | "sampler2DMSArray" | "sampler3D" | "samplerCube" | "samplerCubeArray" => { + declaration.overloads.push(module.add_builtin( + vec![ + TypeInner::Image { + dim: match name { + "sampler1D" | "sampler1DArray" => Dim::D1, + "sampler2D" | "sampler2DArray" | "sampler2DMS" | "sampler2DMSArray" => { + Dim::D2 + } + "sampler3D" => Dim::D3, + _ => Dim::Cube, + }, + arrayed: matches!( + name, + "sampler1DArray" + | "sampler2DArray" + | "sampler2DMSArray" + | "samplerCubeArray" + ), + class: ImageClass::Sampled { + kind: Sk::Float, + multi: matches!(name, "sampler2DMS" | "sampler2DMSArray"), + }, + }, + TypeInner::Sampler { comparison: false }, + ], + MacroCall::Sampler, + )) + } + "sampler1DShadow" + | "sampler1DArrayShadow" + | "sampler2DShadow" + | "sampler2DArrayShadow" + | "samplerCubeShadow" + | "samplerCubeArrayShadow" => { + let dim = match name { + "sampler1DShadow" | "sampler1DArrayShadow" => Dim::D1, + "sampler2DShadow" | "sampler2DArrayShadow" => Dim::D2, + _ => Dim::Cube, + }; + let arrayed = matches!( + name, + "sampler1DArrayShadow" | "sampler2DArrayShadow" | "samplerCubeArrayShadow" + ); + + for i in 0..2 { + let ty = TypeInner::Image { + dim, + arrayed, + class: match i { + 0 => ImageClass::Sampled { + kind: Sk::Float, + multi: false, + }, + _ => ImageClass::Depth { multi: false }, + }, + }; + + declaration.overloads.push(module.add_builtin( + vec![ty, TypeInner::Sampler { comparison: true }], + MacroCall::SamplerShadow, + )) + } + } + "sin" | "exp" | "exp2" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin" + | "log" | "log2" | "radians" | "degrees" | "asinh" | "acosh" | "atanh" + | "floatBitsToInt" | "floatBitsToUint" | "dFdx" | "dFdxFine" | "dFdxCoarse" | "dFdy" + | "dFdyFine" | "dFdyCoarse" | "fwidth" | "fwidthFine" | "fwidthCoarse" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = Scalar::F32; + + declaration.overloads.push(module.add_builtin( + vec![match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }], + match name { + "sin" => MacroCall::MathFunction(MathFunction::Sin), + "exp" => MacroCall::MathFunction(MathFunction::Exp), + "exp2" => MacroCall::MathFunction(MathFunction::Exp2), + "sinh" => MacroCall::MathFunction(MathFunction::Sinh), + "cos" => MacroCall::MathFunction(MathFunction::Cos), + "cosh" => MacroCall::MathFunction(MathFunction::Cosh), + "tan" => MacroCall::MathFunction(MathFunction::Tan), + "tanh" => MacroCall::MathFunction(MathFunction::Tanh), + "acos" => MacroCall::MathFunction(MathFunction::Acos), + "asin" => MacroCall::MathFunction(MathFunction::Asin), + "log" => MacroCall::MathFunction(MathFunction::Log), + "log2" => MacroCall::MathFunction(MathFunction::Log2), + "asinh" => MacroCall::MathFunction(MathFunction::Asinh), + "acosh" => MacroCall::MathFunction(MathFunction::Acosh), + "atanh" => MacroCall::MathFunction(MathFunction::Atanh), + "radians" => MacroCall::MathFunction(MathFunction::Radians), + "degrees" => MacroCall::MathFunction(MathFunction::Degrees), + "floatBitsToInt" => MacroCall::BitCast(Sk::Sint), + "floatBitsToUint" => MacroCall::BitCast(Sk::Uint), + "dFdxCoarse" => MacroCall::Derivate(Axis::X, Ctrl::Coarse), + "dFdyCoarse" => MacroCall::Derivate(Axis::Y, Ctrl::Coarse), + "fwidthCoarse" => MacroCall::Derivate(Axis::Width, Ctrl::Coarse), + "dFdxFine" => MacroCall::Derivate(Axis::X, Ctrl::Fine), + "dFdyFine" => MacroCall::Derivate(Axis::Y, Ctrl::Fine), + "fwidthFine" => MacroCall::Derivate(Axis::Width, Ctrl::Fine), + "dFdx" => MacroCall::Derivate(Axis::X, Ctrl::None), + "dFdy" => MacroCall::Derivate(Axis::Y, Ctrl::None), + "fwidth" => MacroCall::Derivate(Axis::Width, Ctrl::None), + _ => unreachable!(), + }, + )) + } + } + "intBitsToFloat" | "uintBitsToFloat" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = match name { + "intBitsToFloat" => Scalar::I32, + _ => Scalar::U32, + }; + + declaration.overloads.push(module.add_builtin( + vec![match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }], + MacroCall::BitCast(Sk::Float), + )) + } + } + "pow" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = Scalar::F32; + let ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + declaration.overloads.push( + module + .add_builtin(vec![ty(), ty()], MacroCall::MathFunction(MathFunction::Pow)), + ) + } + } + "abs" | "sign" => { + // bits layout + // bit 0 through 1 - dims + // bit 2 - float/sint + for bits in 0..0b1000 { + let size = match bits & 0b11 { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = match bits >> 2 { + 0b0 => Scalar::F32, + _ => Scalar::I32, + }; + + let args = vec![match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }]; + + declaration.overloads.push(module.add_builtin( + args, + MacroCall::MathFunction(match name { + "abs" => MathFunction::Abs, + "sign" => MathFunction::Sign, + _ => unreachable!(), + }), + )) + } + } + "bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" | "findLSB" + | "findMSB" => { + let fun = match name { + "bitCount" => MathFunction::CountOneBits, + "bitfieldReverse" => MathFunction::ReverseBits, + "bitfieldExtract" => MathFunction::ExtractBits, + "bitfieldInsert" => MathFunction::InsertBits, + "findLSB" => MathFunction::FindLsb, + "findMSB" => MathFunction::FindMsb, + _ => unreachable!(), + }; + + let mc = match fun { + MathFunction::ExtractBits => MacroCall::BitfieldExtract, + MathFunction::InsertBits => MacroCall::BitfieldInsert, + _ => MacroCall::MathFunction(fun), + }; + + // bits layout + // bit 0 - int/uint + // bit 1 through 2 - dims + for bits in 0..0b1000 { + let scalar = match bits & 0b1 { + 0b0 => Scalar::I32, + _ => Scalar::U32, + }; + let size = match bits >> 1 { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + let ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + let mut args = vec![ty()]; + + match fun { + MathFunction::ExtractBits => { + args.push(TypeInner::Scalar(Scalar::I32)); + args.push(TypeInner::Scalar(Scalar::I32)); + } + MathFunction::InsertBits => { + args.push(ty()); + args.push(TypeInner::Scalar(Scalar::I32)); + args.push(TypeInner::Scalar(Scalar::I32)); + } + _ => {} + } + + // we need to cast the return type of findLsb / findMsb + let mc = if scalar.kind == Sk::Uint { + match mc { + MacroCall::MathFunction(MathFunction::FindLsb) => MacroCall::FindLsbUint, + MacroCall::MathFunction(MathFunction::FindMsb) => MacroCall::FindMsbUint, + mc => mc, + } + } else { + mc + }; + + declaration.overloads.push(module.add_builtin(args, mc)) + } + } + "packSnorm4x8" | "packUnorm4x8" | "packSnorm2x16" | "packUnorm2x16" | "packHalf2x16" => { + let fun = match name { + "packSnorm4x8" => MathFunction::Pack4x8snorm, + "packUnorm4x8" => MathFunction::Pack4x8unorm, + "packSnorm2x16" => MathFunction::Pack2x16unorm, + "packUnorm2x16" => MathFunction::Pack2x16snorm, + "packHalf2x16" => MathFunction::Pack2x16float, + _ => unreachable!(), + }; + + let ty = match fun { + MathFunction::Pack4x8snorm | MathFunction::Pack4x8unorm => TypeInner::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::F32, + }, + MathFunction::Pack2x16unorm + | MathFunction::Pack2x16snorm + | MathFunction::Pack2x16float => TypeInner::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar::F32, + }, + _ => unreachable!(), + }; + + let args = vec![ty]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(fun))); + } + "unpackSnorm4x8" | "unpackUnorm4x8" | "unpackSnorm2x16" | "unpackUnorm2x16" + | "unpackHalf2x16" => { + let fun = match name { + "unpackSnorm4x8" => MathFunction::Unpack4x8snorm, + "unpackUnorm4x8" => MathFunction::Unpack4x8unorm, + "unpackSnorm2x16" => MathFunction::Unpack2x16snorm, + "unpackUnorm2x16" => MathFunction::Unpack2x16unorm, + "unpackHalf2x16" => MathFunction::Unpack2x16float, + _ => unreachable!(), + }; + + let args = vec![TypeInner::Scalar(Scalar::U32)]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(fun))); + } + "atan" => { + // bits layout + // bit 0 - atan/atan2 + // bit 1 through 2 - dims + for bits in 0..0b1000 { + let fun = match bits & 0b1 { + 0b0 => MathFunction::Atan, + _ => MathFunction::Atan2, + }; + let size = match bits >> 1 { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = Scalar::F32; + let ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + let mut args = vec![ty()]; + + if fun == MathFunction::Atan2 { + args.push(ty()) + } + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(fun))) + } + } + "all" | "any" | "not" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b11 { + let size = match bits { + 0b00 => VectorSize::Bi, + 0b01 => VectorSize::Tri, + _ => VectorSize::Quad, + }; + + let args = vec![TypeInner::Vector { + size, + scalar: Scalar::BOOL, + }]; + + let fun = match name { + "all" => MacroCall::Relational(RelationalFunction::All), + "any" => MacroCall::Relational(RelationalFunction::Any), + "not" => MacroCall::Unary(UnaryOperator::LogicalNot), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" => { + for bits in 0..0b1001 { + let (size, scalar) = match bits { + 0b0000 => (VectorSize::Bi, Scalar::F32), + 0b0001 => (VectorSize::Tri, Scalar::F32), + 0b0010 => (VectorSize::Quad, Scalar::F32), + 0b0011 => (VectorSize::Bi, Scalar::I32), + 0b0100 => (VectorSize::Tri, Scalar::I32), + 0b0101 => (VectorSize::Quad, Scalar::I32), + 0b0110 => (VectorSize::Bi, Scalar::U32), + 0b0111 => (VectorSize::Tri, Scalar::U32), + _ => (VectorSize::Quad, Scalar::U32), + }; + + let ty = || TypeInner::Vector { size, scalar }; + let args = vec![ty(), ty()]; + + let fun = MacroCall::Binary(match name { + "lessThan" => BinaryOperator::Less, + "greaterThan" => BinaryOperator::Greater, + "lessThanEqual" => BinaryOperator::LessEqual, + "greaterThanEqual" => BinaryOperator::GreaterEqual, + _ => unreachable!(), + }); + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "equal" | "notEqual" => { + for bits in 0..0b1100 { + let (size, scalar) = match bits { + 0b0000 => (VectorSize::Bi, Scalar::F32), + 0b0001 => (VectorSize::Tri, Scalar::F32), + 0b0010 => (VectorSize::Quad, Scalar::F32), + 0b0011 => (VectorSize::Bi, Scalar::I32), + 0b0100 => (VectorSize::Tri, Scalar::I32), + 0b0101 => (VectorSize::Quad, Scalar::I32), + 0b0110 => (VectorSize::Bi, Scalar::U32), + 0b0111 => (VectorSize::Tri, Scalar::U32), + 0b1000 => (VectorSize::Quad, Scalar::U32), + 0b1001 => (VectorSize::Bi, Scalar::BOOL), + 0b1010 => (VectorSize::Tri, Scalar::BOOL), + _ => (VectorSize::Quad, Scalar::BOOL), + }; + + let ty = || TypeInner::Vector { size, scalar }; + let args = vec![ty(), ty()]; + + let fun = MacroCall::Binary(match name { + "equal" => BinaryOperator::Equal, + "notEqual" => BinaryOperator::NotEqual, + _ => unreachable!(), + }); + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "min" | "max" => { + // bits layout + // bit 0 through 1 - scalar kind + // bit 2 through 4 - dims + for bits in 0..0b11100 { + let scalar = match bits & 0b11 { + 0b00 => Scalar::F32, + 0b01 => Scalar::I32, + 0b10 => Scalar::U32, + _ => continue, + }; + let (size, second_size) = match bits >> 2 { + 0b000 => (None, None), + 0b001 => (Some(VectorSize::Bi), None), + 0b010 => (Some(VectorSize::Tri), None), + 0b011 => (Some(VectorSize::Quad), None), + 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)), + 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)), + _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)), + }; + + let args = vec![ + match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }, + match second_size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }, + ]; + + let fun = match name { + "max" => MacroCall::Splatted(MathFunction::Max, size, 1), + "min" => MacroCall::Splatted(MathFunction::Min, size, 1), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "mix" => { + // bits layout + // bit 0 through 1 - dims + // bit 2 through 4 - types + // + // 0b10011 is the last element since splatted single elements + // were already added + for bits in 0..0b10011 { + let size = match bits & 0b11 { + 0b00 => Some(VectorSize::Bi), + 0b01 => Some(VectorSize::Tri), + 0b10 => Some(VectorSize::Quad), + _ => None, + }; + let (scalar, splatted, boolean) = match bits >> 2 { + 0b000 => (Scalar::I32, false, true), + 0b001 => (Scalar::U32, false, true), + 0b010 => (Scalar::F32, false, true), + 0b011 => (Scalar::F32, false, false), + _ => (Scalar::F32, true, false), + }; + + let ty = |scalar| match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + let args = vec![ + ty(scalar), + ty(scalar), + match (boolean, splatted) { + (true, _) => ty(Scalar::BOOL), + (_, false) => TypeInner::Scalar(scalar), + _ => ty(scalar), + }, + ]; + + declaration.overloads.push(module.add_builtin( + args, + match boolean { + true => MacroCall::MixBoolean, + false => MacroCall::Splatted(MathFunction::Mix, size, 2), + }, + )) + } + } + "clamp" => { + // bits layout + // bit 0 through 1 - float/int/uint + // bit 2 through 3 - dims + // bit 4 - splatted + // + // 0b11010 is the last element since splatted single elements + // were already added + for bits in 0..0b11011 { + let scalar = match bits & 0b11 { + 0b00 => Scalar::F32, + 0b01 => Scalar::I32, + 0b10 => Scalar::U32, + _ => continue, + }; + let size = match (bits >> 2) & 0b11 { + 0b00 => Some(VectorSize::Bi), + 0b01 => Some(VectorSize::Tri), + 0b10 => Some(VectorSize::Quad), + _ => None, + }; + let splatted = bits & 0b10000 == 0b10000; + + let base_ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + let limit_ty = || match splatted { + true => TypeInner::Scalar(scalar), + false => base_ty(), + }; + + let args = vec![base_ty(), limit_ty(), limit_ty()]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::Clamp(size))) + } + } + "barrier" => declaration + .overloads + .push(module.add_builtin(Vec::new(), MacroCall::Barrier)), + // Add common builtins with floats + _ => inject_common_builtin(declaration, module, name, 4), + } +} + +/// Injects the builtins into declaration that need doubles +fn inject_double_builtin(declaration: &mut FunctionDeclaration, module: &mut Module, name: &str) { + match name { + "abs" | "sign" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = Scalar::F64; + + let args = vec![match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }]; + + declaration.overloads.push(module.add_builtin( + args, + MacroCall::MathFunction(match name { + "abs" => MathFunction::Abs, + "sign" => MathFunction::Sign, + _ => unreachable!(), + }), + )) + } + } + "min" | "max" => { + // bits layout + // bit 0 through 2 - dims + for bits in 0..0b111 { + let (size, second_size) = match bits { + 0b000 => (None, None), + 0b001 => (Some(VectorSize::Bi), None), + 0b010 => (Some(VectorSize::Tri), None), + 0b011 => (Some(VectorSize::Quad), None), + 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)), + 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)), + _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)), + }; + let scalar = Scalar::F64; + + let args = vec![ + match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }, + match second_size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }, + ]; + + let fun = match name { + "max" => MacroCall::Splatted(MathFunction::Max, size, 1), + "min" => MacroCall::Splatted(MathFunction::Min, size, 1), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "mix" => { + // bits layout + // bit 0 through 1 - dims + // bit 2 through 3 - splatted/boolean + // + // 0b1010 is the last element since splatted with single elements + // is equal to normal single elements + for bits in 0..0b1011 { + let size = match bits & 0b11 { + 0b00 => Some(VectorSize::Quad), + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => None, + }; + let scalar = Scalar::F64; + let (splatted, boolean) = match bits >> 2 { + 0b00 => (false, false), + 0b01 => (false, true), + _ => (true, false), + }; + + let ty = |scalar| match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + let args = vec![ + ty(scalar), + ty(scalar), + match (boolean, splatted) { + (true, _) => ty(Scalar::BOOL), + (_, false) => TypeInner::Scalar(scalar), + _ => ty(scalar), + }, + ]; + + declaration.overloads.push(module.add_builtin( + args, + match boolean { + true => MacroCall::MixBoolean, + false => MacroCall::Splatted(MathFunction::Mix, size, 2), + }, + )) + } + } + "clamp" => { + // bits layout + // bit 0 through 1 - dims + // bit 2 - splatted + // + // 0b110 is the last element since splatted with single elements + // is equal to normal single elements + for bits in 0..0b111 { + let scalar = Scalar::F64; + let size = match bits & 0b11 { + 0b00 => Some(VectorSize::Bi), + 0b01 => Some(VectorSize::Tri), + 0b10 => Some(VectorSize::Quad), + _ => None, + }; + let splatted = bits & 0b100 == 0b100; + + let base_ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + let limit_ty = || match splatted { + true => TypeInner::Scalar(scalar), + false => base_ty(), + }; + + let args = vec![base_ty(), limit_ty(), limit_ty()]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::Clamp(size))) + } + } + "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" | "equal" + | "notEqual" => { + let scalar = Scalar::F64; + for bits in 0..0b11 { + let size = match bits { + 0b00 => VectorSize::Bi, + 0b01 => VectorSize::Tri, + _ => VectorSize::Quad, + }; + + let ty = || TypeInner::Vector { size, scalar }; + let args = vec![ty(), ty()]; + + let fun = MacroCall::Binary(match name { + "lessThan" => BinaryOperator::Less, + "greaterThan" => BinaryOperator::Greater, + "lessThanEqual" => BinaryOperator::LessEqual, + "greaterThanEqual" => BinaryOperator::GreaterEqual, + "equal" => BinaryOperator::Equal, + "notEqual" => BinaryOperator::NotEqual, + _ => unreachable!(), + }); + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + // Add common builtins with doubles + _ => inject_common_builtin(declaration, module, name, 8), + } +} + +/// Injects the builtins into declaration that can used either float or doubles +fn inject_common_builtin( + declaration: &mut FunctionDeclaration, + module: &mut Module, + name: &str, + float_width: crate::Bytes, +) { + let float_scalar = Scalar { + kind: Sk::Float, + width: float_width, + }; + match name { + "ceil" | "round" | "roundEven" | "floor" | "fract" | "trunc" | "sqrt" | "inversesqrt" + | "normalize" | "length" | "isinf" | "isnan" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + let args = vec![match size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }]; + + let fun = match name { + "ceil" => MacroCall::MathFunction(MathFunction::Ceil), + "round" | "roundEven" => MacroCall::MathFunction(MathFunction::Round), + "floor" => MacroCall::MathFunction(MathFunction::Floor), + "fract" => MacroCall::MathFunction(MathFunction::Fract), + "trunc" => MacroCall::MathFunction(MathFunction::Trunc), + "sqrt" => MacroCall::MathFunction(MathFunction::Sqrt), + "inversesqrt" => MacroCall::MathFunction(MathFunction::InverseSqrt), + "normalize" => MacroCall::MathFunction(MathFunction::Normalize), + "length" => MacroCall::MathFunction(MathFunction::Length), + "isinf" => MacroCall::Relational(RelationalFunction::IsInf), + "isnan" => MacroCall::Relational(RelationalFunction::IsNan), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "dot" | "reflect" | "distance" | "ldexp" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let ty = |scalar| match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + let fun = match name { + "dot" => MacroCall::MathFunction(MathFunction::Dot), + "reflect" => MacroCall::MathFunction(MathFunction::Reflect), + "distance" => MacroCall::MathFunction(MathFunction::Distance), + "ldexp" => MacroCall::MathFunction(MathFunction::Ldexp), + _ => unreachable!(), + }; + + let second_scalar = match fun { + MacroCall::MathFunction(MathFunction::Ldexp) => Scalar::I32, + _ => float_scalar, + }; + + declaration + .overloads + .push(module.add_builtin(vec![ty(float_scalar), ty(second_scalar)], fun)) + } + } + "transpose" => { + // bits layout + // bit 0 through 3 - dims + for bits in 0..0b1001 { + let (rows, columns) = match bits { + 0b0000 => (VectorSize::Bi, VectorSize::Bi), + 0b0001 => (VectorSize::Bi, VectorSize::Tri), + 0b0010 => (VectorSize::Bi, VectorSize::Quad), + 0b0011 => (VectorSize::Tri, VectorSize::Bi), + 0b0100 => (VectorSize::Tri, VectorSize::Tri), + 0b0101 => (VectorSize::Tri, VectorSize::Quad), + 0b0110 => (VectorSize::Quad, VectorSize::Bi), + 0b0111 => (VectorSize::Quad, VectorSize::Tri), + _ => (VectorSize::Quad, VectorSize::Quad), + }; + + declaration.overloads.push(module.add_builtin( + vec![TypeInner::Matrix { + columns, + rows, + scalar: float_scalar, + }], + MacroCall::MathFunction(MathFunction::Transpose), + )) + } + } + "inverse" | "determinant" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b11 { + let (rows, columns) = match bits { + 0b00 => (VectorSize::Bi, VectorSize::Bi), + 0b01 => (VectorSize::Tri, VectorSize::Tri), + _ => (VectorSize::Quad, VectorSize::Quad), + }; + + let args = vec![TypeInner::Matrix { + columns, + rows, + scalar: float_scalar, + }]; + + declaration.overloads.push(module.add_builtin( + args, + MacroCall::MathFunction(match name { + "inverse" => MathFunction::Inverse, + "determinant" => MathFunction::Determinant, + _ => unreachable!(), + }), + )) + } + } + "mod" | "step" => { + // bits layout + // bit 0 through 2 - dims + for bits in 0..0b111 { + let (size, second_size) = match bits { + 0b000 => (None, None), + 0b001 => (Some(VectorSize::Bi), None), + 0b010 => (Some(VectorSize::Tri), None), + 0b011 => (Some(VectorSize::Quad), None), + 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)), + 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)), + _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)), + }; + + let mut args = Vec::with_capacity(2); + let step = name == "step"; + + for i in 0..2 { + let maybe_size = match i == step as u32 { + true => size, + false => second_size, + }; + + args.push(match maybe_size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }) + } + + let fun = match name { + "mod" => MacroCall::Mod(size), + "step" => MacroCall::Splatted(MathFunction::Step, size, 0), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + // TODO: https://github.com/gfx-rs/naga/issues/2526 + // "modf" | "frexp" => { ... } + "cross" => { + let args = vec![ + TypeInner::Vector { + size: VectorSize::Tri, + scalar: float_scalar, + }, + TypeInner::Vector { + size: VectorSize::Tri, + scalar: float_scalar, + }, + ]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Cross))) + } + "outerProduct" => { + // bits layout + // bit 0 through 3 - dims + for bits in 0..0b1001 { + let (size1, size2) = match bits { + 0b0000 => (VectorSize::Bi, VectorSize::Bi), + 0b0001 => (VectorSize::Bi, VectorSize::Tri), + 0b0010 => (VectorSize::Bi, VectorSize::Quad), + 0b0011 => (VectorSize::Tri, VectorSize::Bi), + 0b0100 => (VectorSize::Tri, VectorSize::Tri), + 0b0101 => (VectorSize::Tri, VectorSize::Quad), + 0b0110 => (VectorSize::Quad, VectorSize::Bi), + 0b0111 => (VectorSize::Quad, VectorSize::Tri), + _ => (VectorSize::Quad, VectorSize::Quad), + }; + + let args = vec![ + TypeInner::Vector { + size: size1, + scalar: float_scalar, + }, + TypeInner::Vector { + size: size2, + scalar: float_scalar, + }, + ]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Outer))) + } + } + "faceforward" | "fma" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + let ty = || match size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }; + let args = vec![ty(), ty(), ty()]; + + let fun = match name { + "faceforward" => MacroCall::MathFunction(MathFunction::FaceForward), + "fma" => MacroCall::MathFunction(MathFunction::Fma), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "refract" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + let ty = || match size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }; + let args = vec![ty(), ty(), TypeInner::Scalar(Scalar::F32)]; + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Refract))) + } + } + "smoothstep" => { + // bit 0 - splatted + // bit 1 through 2 - dims + for bits in 0..0b1000 { + let splatted = bits & 0b1 == 0b1; + let size = match bits >> 1 { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + if splatted && size.is_none() { + continue; + } + + let base_ty = || match size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }; + let ty = || match splatted { + true => TypeInner::Scalar(float_scalar), + false => base_ty(), + }; + declaration.overloads.push(module.add_builtin( + vec![ty(), ty(), base_ty()], + MacroCall::SmoothStep { splatted: size }, + )) + } + } + // The function isn't a builtin or we don't yet support it + _ => {} + } +} + +#[derive(Clone, Copy, PartialEq, Debug)] +pub enum TextureLevelType { + None, + Lod, + Grad, +} + +/// A compiler defined builtin function +#[derive(Clone, Copy, PartialEq, Debug)] +pub enum MacroCall { + Sampler, + SamplerShadow, + Texture { + proj: bool, + offset: bool, + shadow: bool, + level_type: TextureLevelType, + }, + TextureSize { + arrayed: bool, + }, + ImageLoad { + multi: bool, + }, + ImageStore, + MathFunction(MathFunction), + FindLsbUint, + FindMsbUint, + BitfieldExtract, + BitfieldInsert, + Relational(RelationalFunction), + Unary(UnaryOperator), + Binary(BinaryOperator), + Mod(Option<VectorSize>), + Splatted(MathFunction, Option<VectorSize>, usize), + MixBoolean, + Clamp(Option<VectorSize>), + BitCast(Sk), + Derivate(Axis, Ctrl), + Barrier, + /// SmoothStep needs a separate variant because it might need it's inputs + /// to be splatted depending on the overload + SmoothStep { + /// The size of the splat operation if some + splatted: Option<VectorSize>, + }, +} + +impl MacroCall { + /// Adds the necessary expressions and statements to the passed body and + /// finally returns the final expression with the correct result + pub fn call( + &self, + frontend: &mut Frontend, + ctx: &mut Context, + args: &mut [Handle<Expression>], + meta: Span, + ) -> Result<Option<Handle<Expression>>> { + Ok(Some(match *self { + MacroCall::Sampler => { + ctx.samplers.insert(args[0], args[1]); + args[0] + } + MacroCall::SamplerShadow => { + sampled_to_depth(ctx, args[0], meta, &mut frontend.errors); + ctx.invalidate_expression(args[0], meta)?; + ctx.samplers.insert(args[0], args[1]); + args[0] + } + MacroCall::Texture { + proj, + offset, + shadow, + level_type, + } => { + let mut coords = args[1]; + + if proj { + let size = match *ctx.resolve_type(coords, meta)? { + TypeInner::Vector { size, .. } => size, + _ => unreachable!(), + }; + let mut right = ctx.add_expression( + Expression::AccessIndex { + base: coords, + index: size as u32 - 1, + }, + Span::default(), + )?; + let left = if let VectorSize::Bi = size { + ctx.add_expression( + Expression::AccessIndex { + base: coords, + index: 0, + }, + Span::default(), + )? + } else { + let size = match size { + VectorSize::Tri => VectorSize::Bi, + _ => VectorSize::Tri, + }; + right = ctx.add_expression( + Expression::Splat { size, value: right }, + Span::default(), + )?; + ctx.vector_resize(size, coords, Span::default())? + }; + coords = ctx.add_expression( + Expression::Binary { + op: BinaryOperator::Divide, + left, + right, + }, + Span::default(), + )?; + } + + let extra = args.get(2).copied(); + let comps = frontend.coordinate_components(ctx, args[0], coords, extra, meta)?; + + let mut num_args = 2; + + if comps.used_extra { + num_args += 1; + }; + + // Parse out explicit texture level. + let mut level = match level_type { + TextureLevelType::None => SampleLevel::Auto, + + TextureLevelType::Lod => { + num_args += 1; + + if shadow { + log::warn!("Assuming LOD {:?} is zero", args[2],); + + SampleLevel::Zero + } else { + SampleLevel::Exact(args[2]) + } + } + + TextureLevelType::Grad => { + num_args += 2; + + if shadow { + log::warn!( + "Assuming gradients {:?} and {:?} are not greater than 1", + args[2], + args[3], + ); + SampleLevel::Zero + } else { + SampleLevel::Gradient { + x: args[2], + y: args[3], + } + } + } + }; + + let texture_offset = match offset { + true => { + let offset_arg = args[num_args]; + num_args += 1; + match ctx.lift_up_const_expression(offset_arg) { + Ok(v) => Some(v), + Err(e) => { + frontend.errors.push(e); + None + } + } + } + false => None, + }; + + // Now go back and look for optional bias arg (if available) + if let TextureLevelType::None = level_type { + level = args + .get(num_args) + .copied() + .map_or(SampleLevel::Auto, SampleLevel::Bias); + } + + texture_call(ctx, args[0], level, comps, texture_offset, meta)? + } + + MacroCall::TextureSize { arrayed } => { + let mut expr = ctx.add_expression( + Expression::ImageQuery { + image: args[0], + query: ImageQuery::Size { + level: args.get(1).copied(), + }, + }, + Span::default(), + )?; + + if arrayed { + let mut components = Vec::with_capacity(4); + + let size = match *ctx.resolve_type(expr, meta)? { + TypeInner::Vector { size: ori_size, .. } => { + for index in 0..(ori_size as u32) { + components.push(ctx.add_expression( + Expression::AccessIndex { base: expr, index }, + Span::default(), + )?) + } + + match ori_size { + VectorSize::Bi => VectorSize::Tri, + _ => VectorSize::Quad, + } + } + _ => { + components.push(expr); + VectorSize::Bi + } + }; + + components.push(ctx.add_expression( + Expression::ImageQuery { + image: args[0], + query: ImageQuery::NumLayers, + }, + Span::default(), + )?); + + let ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size, + scalar: Scalar::U32, + }, + }, + Span::default(), + ); + + expr = ctx.add_expression(Expression::Compose { components, ty }, meta)? + } + + ctx.add_expression( + Expression::As { + expr, + kind: Sk::Sint, + convert: Some(4), + }, + Span::default(), + )? + } + MacroCall::ImageLoad { multi } => { + let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?; + let (sample, level) = match (multi, args.get(2)) { + (_, None) => (None, None), + (true, Some(&arg)) => (Some(arg), None), + (false, Some(&arg)) => (None, Some(arg)), + }; + ctx.add_expression( + Expression::ImageLoad { + image: args[0], + coordinate: comps.coordinate, + array_index: comps.array_index, + sample, + level, + }, + Span::default(), + )? + } + MacroCall::ImageStore => { + let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?; + ctx.emit_restart(); + ctx.body.push( + crate::Statement::ImageStore { + image: args[0], + coordinate: comps.coordinate, + array_index: comps.array_index, + value: args[2], + }, + meta, + ); + return Ok(None); + } + MacroCall::MathFunction(fun) => ctx.add_expression( + Expression::Math { + fun, + arg: args[0], + arg1: args.get(1).copied(), + arg2: args.get(2).copied(), + arg3: args.get(3).copied(), + }, + Span::default(), + )?, + mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => { + let fun = match mc { + MacroCall::FindLsbUint => MathFunction::FindLsb, + MacroCall::FindMsbUint => MathFunction::FindMsb, + _ => unreachable!(), + }; + let res = ctx.add_expression( + Expression::Math { + fun, + arg: args[0], + arg1: None, + arg2: None, + arg3: None, + }, + Span::default(), + )?; + ctx.add_expression( + Expression::As { + expr: res, + kind: Sk::Sint, + convert: Some(4), + }, + Span::default(), + )? + } + MacroCall::BitfieldInsert => { + let conv_arg_2 = ctx.add_expression( + Expression::As { + expr: args[2], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + )?; + let conv_arg_3 = ctx.add_expression( + Expression::As { + expr: args[3], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + )?; + ctx.add_expression( + Expression::Math { + fun: MathFunction::InsertBits, + arg: args[0], + arg1: Some(args[1]), + arg2: Some(conv_arg_2), + arg3: Some(conv_arg_3), + }, + Span::default(), + )? + } + MacroCall::BitfieldExtract => { + let conv_arg_1 = ctx.add_expression( + Expression::As { + expr: args[1], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + )?; + let conv_arg_2 = ctx.add_expression( + Expression::As { + expr: args[2], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + )?; + ctx.add_expression( + Expression::Math { + fun: MathFunction::ExtractBits, + arg: args[0], + arg1: Some(conv_arg_1), + arg2: Some(conv_arg_2), + arg3: None, + }, + Span::default(), + )? + } + MacroCall::Relational(fun) => ctx.add_expression( + Expression::Relational { + fun, + argument: args[0], + }, + Span::default(), + )?, + MacroCall::Unary(op) => { + ctx.add_expression(Expression::Unary { op, expr: args[0] }, Span::default())? + } + MacroCall::Binary(op) => ctx.add_expression( + Expression::Binary { + op, + left: args[0], + right: args[1], + }, + Span::default(), + )?, + MacroCall::Mod(size) => { + ctx.implicit_splat(&mut args[1], meta, size)?; + + // x - y * floor(x / y) + + let div = ctx.add_expression( + Expression::Binary { + op: BinaryOperator::Divide, + left: args[0], + right: args[1], + }, + Span::default(), + )?; + let floor = ctx.add_expression( + Expression::Math { + fun: MathFunction::Floor, + arg: div, + arg1: None, + arg2: None, + arg3: None, + }, + Span::default(), + )?; + let mult = ctx.add_expression( + Expression::Binary { + op: BinaryOperator::Multiply, + left: floor, + right: args[1], + }, + Span::default(), + )?; + ctx.add_expression( + Expression::Binary { + op: BinaryOperator::Subtract, + left: args[0], + right: mult, + }, + Span::default(), + )? + } + MacroCall::Splatted(fun, size, i) => { + ctx.implicit_splat(&mut args[i], meta, size)?; + + ctx.add_expression( + Expression::Math { + fun, + arg: args[0], + arg1: args.get(1).copied(), + arg2: args.get(2).copied(), + arg3: args.get(3).copied(), + }, + Span::default(), + )? + } + MacroCall::MixBoolean => ctx.add_expression( + Expression::Select { + condition: args[2], + accept: args[1], + reject: args[0], + }, + Span::default(), + )?, + MacroCall::Clamp(size) => { + ctx.implicit_splat(&mut args[1], meta, size)?; + ctx.implicit_splat(&mut args[2], meta, size)?; + + ctx.add_expression( + Expression::Math { + fun: MathFunction::Clamp, + arg: args[0], + arg1: args.get(1).copied(), + arg2: args.get(2).copied(), + arg3: args.get(3).copied(), + }, + Span::default(), + )? + } + MacroCall::BitCast(kind) => ctx.add_expression( + Expression::As { + expr: args[0], + kind, + convert: None, + }, + Span::default(), + )?, + MacroCall::Derivate(axis, ctrl) => ctx.add_expression( + Expression::Derivative { + axis, + ctrl, + expr: args[0], + }, + Span::default(), + )?, + MacroCall::Barrier => { + ctx.emit_restart(); + ctx.body + .push(crate::Statement::Barrier(crate::Barrier::all()), meta); + return Ok(None); + } + MacroCall::SmoothStep { splatted } => { + ctx.implicit_splat(&mut args[0], meta, splatted)?; + ctx.implicit_splat(&mut args[1], meta, splatted)?; + + ctx.add_expression( + Expression::Math { + fun: MathFunction::SmoothStep, + arg: args[0], + arg1: args.get(1).copied(), + arg2: args.get(2).copied(), + arg3: None, + }, + Span::default(), + )? + } + })) + } +} + +fn texture_call( + ctx: &mut Context, + image: Handle<Expression>, + level: SampleLevel, + comps: CoordComponents, + offset: Option<Handle<Expression>>, + meta: Span, +) -> Result<Handle<Expression>> { + if let Some(sampler) = ctx.samplers.get(&image).copied() { + let mut array_index = comps.array_index; + + if let Some(ref mut array_index_expr) = array_index { + ctx.conversion(array_index_expr, meta, Scalar::I32)?; + } + + Ok(ctx.add_expression( + Expression::ImageSample { + image, + sampler, + gather: None, //TODO + coordinate: comps.coordinate, + array_index, + offset, + level, + depth_ref: comps.depth_ref, + }, + meta, + )?) + } else { + Err(Error { + kind: ErrorKind::SemanticError("Bad call".into()), + meta, + }) + } +} + +/// Helper struct for texture calls with the separate components from the vector argument +/// +/// Obtained by calling [`coordinate_components`](Frontend::coordinate_components) +#[derive(Debug)] +struct CoordComponents { + coordinate: Handle<Expression>, + depth_ref: Option<Handle<Expression>>, + array_index: Option<Handle<Expression>>, + used_extra: bool, +} + +impl Frontend { + /// Helper function for texture calls, splits the vector argument into it's components + fn coordinate_components( + &mut self, + ctx: &mut Context, + image: Handle<Expression>, + coord: Handle<Expression>, + extra: Option<Handle<Expression>>, + meta: Span, + ) -> Result<CoordComponents> { + if let TypeInner::Image { + dim, + arrayed, + class, + } = *ctx.resolve_type(image, meta)? + { + let image_size = match dim { + Dim::D1 => None, + Dim::D2 => Some(VectorSize::Bi), + Dim::D3 => Some(VectorSize::Tri), + Dim::Cube => Some(VectorSize::Tri), + }; + let coord_size = match *ctx.resolve_type(coord, meta)? { + TypeInner::Vector { size, .. } => Some(size), + _ => None, + }; + let (shadow, storage) = match class { + ImageClass::Depth { .. } => (true, false), + ImageClass::Storage { .. } => (false, true), + ImageClass::Sampled { .. } => (false, false), + }; + + let coordinate = match (image_size, coord_size) { + (Some(size), Some(coord_s)) if size != coord_s => { + ctx.vector_resize(size, coord, Span::default())? + } + (None, Some(_)) => ctx.add_expression( + Expression::AccessIndex { + base: coord, + index: 0, + }, + Span::default(), + )?, + _ => coord, + }; + + let mut coord_index = image_size.map_or(1, |s| s as u32); + + let array_index = if arrayed && !(storage && dim == Dim::Cube) { + let index = coord_index; + coord_index += 1; + + Some(ctx.add_expression( + Expression::AccessIndex { base: coord, index }, + Span::default(), + )?) + } else { + None + }; + let mut used_extra = false; + let depth_ref = match shadow { + true => { + let index = coord_index; + + if index == 4 { + used_extra = true; + extra + } else { + Some(ctx.add_expression( + Expression::AccessIndex { base: coord, index }, + Span::default(), + )?) + } + } + false => None, + }; + + Ok(CoordComponents { + coordinate, + depth_ref, + array_index, + used_extra, + }) + } else { + self.errors.push(Error { + kind: ErrorKind::SemanticError("Type is not an image".into()), + meta, + }); + + Ok(CoordComponents { + coordinate: coord, + depth_ref: None, + array_index: None, + used_extra: false, + }) + } + } +} + +/// Helper function to cast a expression holding a sampled image to a +/// depth image. +pub fn sampled_to_depth( + ctx: &mut Context, + image: Handle<Expression>, + meta: Span, + errors: &mut Vec<Error>, +) { + // Get the a mutable type handle of the underlying image storage + let ty = match ctx[image] { + Expression::GlobalVariable(handle) => &mut ctx.module.global_variables.get_mut(handle).ty, + Expression::FunctionArgument(i) => { + // Mark the function argument as carrying a depth texture + ctx.parameters_info[i as usize].depth = true; + // NOTE: We need to later also change the parameter type + &mut ctx.arguments[i as usize].ty + } + _ => { + // Only globals and function arguments are allowed to carry an image + return errors.push(Error { + kind: ErrorKind::SemanticError("Not a valid texture expression".into()), + meta, + }); + } + }; + + match ctx.module.types[*ty].inner { + // Update the image class to depth in case it already isn't + TypeInner::Image { + class, + dim, + arrayed, + } => match class { + ImageClass::Sampled { multi, .. } => { + *ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Image { + dim, + arrayed, + class: ImageClass::Depth { multi }, + }, + }, + Span::default(), + ) + } + ImageClass::Depth { .. } => {} + // Other image classes aren't allowed to be transformed to depth + ImageClass::Storage { .. } => errors.push(Error { + kind: ErrorKind::SemanticError("Not a texture".into()), + meta, + }), + }, + _ => errors.push(Error { + kind: ErrorKind::SemanticError("Not a texture".into()), + meta, + }), + }; + + // Copy the handle to allow borrowing the `ctx` again + let ty = *ty; + + // If the image was passed through a function argument we also need to change + // the corresponding parameter + if let Expression::FunctionArgument(i) = ctx[image] { + ctx.parameters[i as usize] = ty; + } +} + +bitflags::bitflags! { + /// Influences the operation `texture_args_generator` + struct TextureArgsOptions: u32 { + /// Generates multisampled variants of images + const MULTI = 1 << 0; + /// Generates shadow variants of images + const SHADOW = 1 << 1; + /// Generates standard images + const STANDARD = 1 << 2; + /// Generates cube arrayed images + const CUBE_ARRAY = 1 << 3; + /// Generates cube arrayed images + const D2_MULTI_ARRAY = 1 << 4; + } +} + +impl From<BuiltinVariations> for TextureArgsOptions { + fn from(variations: BuiltinVariations) -> Self { + let mut options = TextureArgsOptions::empty(); + if variations.contains(BuiltinVariations::STANDARD) { + options |= TextureArgsOptions::STANDARD + } + if variations.contains(BuiltinVariations::CUBE_TEXTURES_ARRAY) { + options |= TextureArgsOptions::CUBE_ARRAY + } + if variations.contains(BuiltinVariations::D2_MULTI_TEXTURES_ARRAY) { + options |= TextureArgsOptions::D2_MULTI_ARRAY + } + options + } +} + +/// Helper function to generate the image components for texture/image builtins +/// +/// Calls the passed function `f` with: +/// ```text +/// f(ScalarKind, ImageDimension, arrayed, multi, shadow) +/// ``` +/// +/// `options` controls extra image variants generation like multisampling and depth, +/// see the struct documentation +fn texture_args_generator( + options: TextureArgsOptions, + mut f: impl FnMut(crate::ScalarKind, Dim, bool, bool, bool), +) { + for kind in [Sk::Float, Sk::Uint, Sk::Sint].iter().copied() { + for dim in [Dim::D1, Dim::D2, Dim::D3, Dim::Cube].iter().copied() { + for arrayed in [false, true].iter().copied() { + if dim == Dim::Cube && arrayed { + if !options.contains(TextureArgsOptions::CUBE_ARRAY) { + continue; + } + } else if Dim::D2 == dim + && options.contains(TextureArgsOptions::MULTI) + && arrayed + && options.contains(TextureArgsOptions::D2_MULTI_ARRAY) + { + // multisampling for sampler2DMSArray + f(kind, dim, arrayed, true, false); + } else if !options.contains(TextureArgsOptions::STANDARD) { + continue; + } + + f(kind, dim, arrayed, false, false); + + // 3D images can't be neither arrayed nor shadow + // so we break out early, this way arrayed will always + // be false and we won't hit the shadow branch + if let Dim::D3 = dim { + break; + } + + if Dim::D2 == dim && options.contains(TextureArgsOptions::MULTI) && !arrayed { + // multisampling + f(kind, dim, arrayed, true, false); + } + + if Sk::Float == kind && options.contains(TextureArgsOptions::SHADOW) { + // shadow + f(kind, dim, arrayed, false, true); + } + } + } + } +} + +/// Helper functions used to convert from a image dimension into a integer representing the +/// number of components needed for the coordinates vector (1 means scalar instead of vector) +const fn image_dims_to_coords_size(dim: Dim) -> usize { + match dim { + Dim::D1 => 1, + Dim::D2 => 2, + _ => 3, + } +} diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs new file mode 100644 index 0000000000..f26c57965d --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/context.rs @@ -0,0 +1,1506 @@ +use super::{ + ast::{ + GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier, + VariableReference, + }, + error::{Error, ErrorKind}, + types::{scalar_components, type_power}, + Frontend, Result, +}; +use crate::{ + front::Typifier, proc::Emitter, AddressSpace, Arena, BinaryOperator, Block, Expression, + FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction, Scalar, + Span, Statement, Type, TypeInner, VectorSize, +}; +use std::ops::Index; + +/// The position at which an expression is, used while lowering +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum ExprPos { + /// The expression is in the left hand side of an assignment + Lhs, + /// The expression is in the right hand side of an assignment + Rhs, + /// The expression is an array being indexed, needed to allow constant + /// arrays to be dynamically indexed + AccessBase { + /// The index is a constant + constant_index: bool, + }, +} + +impl ExprPos { + /// Returns an lhs position if the current position is lhs otherwise AccessBase + const fn maybe_access_base(&self, constant_index: bool) -> Self { + match *self { + ExprPos::Lhs + | ExprPos::AccessBase { + constant_index: false, + } => *self, + _ => ExprPos::AccessBase { constant_index }, + } + } +} + +#[derive(Debug)] +pub struct Context<'a> { + pub expressions: Arena<Expression>, + pub locals: Arena<LocalVariable>, + + /// The [`FunctionArgument`]s for the final [`crate::Function`]. + /// + /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types + /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to + /// a [`Vector`]. + /// + /// [`Pointer`]: crate::TypeInner::Pointer + /// [`Vector`]: crate::TypeInner::Vector + pub arguments: Vec<FunctionArgument>, + + /// The parameter types given in the source code. + /// + /// The `out` and `inout` qualifiers don't affect the types that appear + /// here. For example, an `inout vec2 a` argument would simply be a + /// [`Vector`], not a pointer to one. + /// + /// [`Vector`]: crate::TypeInner::Vector + pub parameters: Vec<Handle<Type>>, + pub parameters_info: Vec<ParameterInfo>, + + pub symbol_table: crate::front::SymbolTable<String, VariableReference>, + pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>, + + pub const_typifier: Typifier, + pub typifier: Typifier, + emitter: Emitter, + stmt_ctx: Option<StmtContext>, + pub body: Block, + pub module: &'a mut crate::Module, + pub is_const: bool, + /// Tracks the constness of `Expression`s residing in `self.expressions` + pub expression_constness: crate::proc::ExpressionConstnessTracker, +} + +impl<'a> Context<'a> { + pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result<Self> { + let mut this = Context { + expressions: Arena::new(), + locals: Arena::new(), + arguments: Vec::new(), + + parameters: Vec::new(), + parameters_info: Vec::new(), + + symbol_table: crate::front::SymbolTable::default(), + samplers: FastHashMap::default(), + + const_typifier: Typifier::new(), + typifier: Typifier::new(), + emitter: Emitter::default(), + stmt_ctx: Some(StmtContext::new()), + body: Block::new(), + module, + is_const: false, + expression_constness: crate::proc::ExpressionConstnessTracker::new(), + }; + + this.emit_start(); + + for &(ref name, lookup) in frontend.global_variables.iter() { + this.add_global(name, lookup)? + } + this.is_const = is_const; + + Ok(this) + } + + pub fn new_body<F>(&mut self, cb: F) -> Result<Block> + where + F: FnOnce(&mut Self) -> Result<()>, + { + self.new_body_with_ret(cb).map(|(b, _)| b) + } + + pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)> + where + F: FnOnce(&mut Self) -> Result<R>, + { + self.emit_restart(); + let old_body = std::mem::replace(&mut self.body, Block::new()); + let res = cb(self); + self.emit_restart(); + let new_body = std::mem::replace(&mut self.body, old_body); + res.map(|r| (new_body, r)) + } + + pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block> + where + F: FnOnce(&mut Self) -> Result<()>, + { + self.emit_restart(); + let old_body = std::mem::replace(&mut self.body, body); + let res = cb(self); + self.emit_restart(); + let body = std::mem::replace(&mut self.body, old_body); + res.map(|_| body) + } + + pub fn add_global( + &mut self, + name: &str, + GlobalLookup { + kind, + entry_arg, + mutable, + }: GlobalLookup, + ) -> Result<()> { + let (expr, load, constant) = match kind { + GlobalLookupKind::Variable(v) => { + let span = self.module.global_variables.get_span(v); + ( + self.add_expression(Expression::GlobalVariable(v), span)?, + self.module.global_variables[v].space != AddressSpace::Handle, + None, + ) + } + GlobalLookupKind::BlockSelect(handle, index) => { + let span = self.module.global_variables.get_span(handle); + let base = self.add_expression(Expression::GlobalVariable(handle), span)?; + let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?; + + ( + expr, + { + let ty = self.module.global_variables[handle].ty; + + match self.module.types[ty].inner { + TypeInner::Struct { ref members, .. } => { + if let TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = self.module.types[members[index as usize].ty].inner + { + false + } else { + true + } + } + _ => true, + } + }, + None, + ) + } + GlobalLookupKind::Constant(v, ty) => { + let span = self.module.constants.get_span(v); + ( + self.add_expression(Expression::Constant(v), span)?, + false, + Some((v, ty)), + ) + } + }; + + let var = VariableReference { + expr, + load, + mutable, + constant, + entry_arg, + }; + + self.symbol_table.add(name.into(), var); + + Ok(()) + } + + /// Starts the expression emitter + /// + /// # Panics + /// + /// - If called twice in a row without calling [`emit_end`][Self::emit_end]. + #[inline] + pub fn emit_start(&mut self) { + self.emitter.start(&self.expressions) + } + + /// Emits all the expressions captured by the emitter to the current body + /// + /// # Panics + /// + /// - If called before calling [`emit_start`]. + /// - If called twice in a row without calling [`emit_start`]. + /// + /// [`emit_start`]: Self::emit_start + pub fn emit_end(&mut self) { + self.body.extend(self.emitter.finish(&self.expressions)) + } + + /// Emits all the expressions captured by the emitter to the current body + /// and starts the emitter again + /// + /// # Panics + /// + /// - If called before calling [`emit_start`][Self::emit_start]. + pub fn emit_restart(&mut self) { + self.emit_end(); + self.emit_start() + } + + pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> { + let mut eval = if self.is_const { + crate::proc::ConstantEvaluator::for_glsl_module(self.module) + } else { + crate::proc::ConstantEvaluator::for_glsl_function( + self.module, + &mut self.expressions, + &mut self.expression_constness, + &mut self.emitter, + &mut self.body, + ) + }; + + let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error { + kind: e.into(), + meta, + }); + + match res { + Ok(expr) => Ok(expr), + Err(e) => { + if self.is_const { + Err(e) + } else { + let needs_pre_emit = expr.needs_pre_emit(); + if needs_pre_emit { + self.body.extend(self.emitter.finish(&self.expressions)); + } + let h = self.expressions.append(expr, meta); + if needs_pre_emit { + self.emitter.start(&self.expressions); + } + Ok(h) + } + } + } + } + + /// Add variable to current scope + /// + /// Returns a variable if a variable with the same name was already defined, + /// otherwise returns `None` + pub fn add_local_var( + &mut self, + name: String, + expr: Handle<Expression>, + mutable: bool, + ) -> Option<VariableReference> { + let var = VariableReference { + expr, + load: true, + mutable, + constant: None, + entry_arg: None, + }; + + self.symbol_table.add(name, var) + } + + /// Add function argument to current scope + pub fn add_function_arg( + &mut self, + name_meta: Option<(String, Span)>, + ty: Handle<Type>, + qualifier: ParameterQualifier, + ) -> Result<()> { + let index = self.arguments.len(); + let mut arg = FunctionArgument { + name: name_meta.as_ref().map(|&(ref name, _)| name.clone()), + ty, + binding: None, + }; + self.parameters.push(ty); + + let opaque = match self.module.types[ty].inner { + TypeInner::Image { .. } | TypeInner::Sampler { .. } => true, + _ => false, + }; + + if qualifier.is_lhs() { + let span = self.module.types.get_span(arg.ty); + arg.ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Pointer { + base: arg.ty, + space: AddressSpace::Function, + }, + }, + span, + ) + } + + self.arguments.push(arg); + + self.parameters_info.push(ParameterInfo { + qualifier, + depth: false, + }); + + if let Some((name, meta)) = name_meta { + let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?; + let mutable = qualifier != ParameterQualifier::Const && !opaque; + let load = qualifier.is_lhs(); + + let var = if mutable && !load { + let handle = self.locals.append( + LocalVariable { + name: Some(name.clone()), + ty, + init: None, + }, + meta, + ); + let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?; + + self.emit_restart(); + + self.body.push( + Statement::Store { + pointer: local_expr, + value: expr, + }, + meta, + ); + + VariableReference { + expr: local_expr, + load: true, + mutable, + constant: None, + entry_arg: None, + } + } else { + VariableReference { + expr, + load, + mutable, + constant: None, + entry_arg: None, + } + }; + + self.symbol_table.add(name, var); + } + + Ok(()) + } + + /// Returns a [`StmtContext`] to be used in parsing and lowering + /// + /// # Panics + /// + /// - If more than one [`StmtContext`] are active at the same time or if the + /// previous call didn't use it in lowering. + #[must_use] + pub fn stmt_ctx(&mut self) -> StmtContext { + self.stmt_ctx.take().unwrap() + } + + /// Lowers a [`HirExpr`] which might produce a [`Expression`]. + /// + /// consumes a [`StmtContext`] returning it to the context so that it can be + /// used again later. + pub fn lower( + &mut self, + mut stmt: StmtContext, + frontend: &mut Frontend, + expr: Handle<HirExpr>, + pos: ExprPos, + ) -> Result<(Option<Handle<Expression>>, Span)> { + let res = self.lower_inner(&stmt, frontend, expr, pos); + + stmt.hir_exprs.clear(); + self.stmt_ctx = Some(stmt); + + res + } + + /// Similar to [`lower`](Self::lower) but returns an error if the expression + /// returns void (ie. doesn't produce a [`Expression`]). + /// + /// consumes a [`StmtContext`] returning it to the context so that it can be + /// used again later. + pub fn lower_expect( + &mut self, + mut stmt: StmtContext, + frontend: &mut Frontend, + expr: Handle<HirExpr>, + pos: ExprPos, + ) -> Result<(Handle<Expression>, Span)> { + let res = self.lower_expect_inner(&stmt, frontend, expr, pos); + + stmt.hir_exprs.clear(); + self.stmt_ctx = Some(stmt); + + res + } + + /// internal implementation of [`lower_expect`](Self::lower_expect) + /// + /// this method is only public because it's used in + /// [`function_call`](Frontend::function_call), unless you know what + /// you're doing use [`lower_expect`](Self::lower_expect) + pub fn lower_expect_inner( + &mut self, + stmt: &StmtContext, + frontend: &mut Frontend, + expr: Handle<HirExpr>, + pos: ExprPos, + ) -> Result<(Handle<Expression>, Span)> { + let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?; + + let expr = match maybe_expr { + Some(e) => e, + None => { + return Err(Error { + kind: ErrorKind::SemanticError("Expression returns void".into()), + meta, + }) + } + }; + + Ok((expr, meta)) + } + + fn lower_store( + &mut self, + pointer: Handle<Expression>, + value: Handle<Expression>, + meta: Span, + ) -> Result<()> { + if let Expression::Swizzle { + size, + mut vector, + pattern, + } = self.expressions[pointer] + { + // Stores to swizzled values are not directly supported, + // lower them as series of per-component stores. + let size = match size { + VectorSize::Bi => 2, + VectorSize::Tri => 3, + VectorSize::Quad => 4, + }; + + if let Expression::Load { pointer } = self.expressions[vector] { + vector = pointer; + } + + #[allow(clippy::needless_range_loop)] + for index in 0..size { + let dst = self.add_expression( + Expression::AccessIndex { + base: vector, + index: pattern[index].index(), + }, + meta, + )?; + let src = self.add_expression( + Expression::AccessIndex { + base: value, + index: index as u32, + }, + meta, + )?; + + self.emit_restart(); + + self.body.push( + Statement::Store { + pointer: dst, + value: src, + }, + meta, + ); + } + } else { + self.emit_restart(); + + self.body.push(Statement::Store { pointer, value }, meta); + } + + Ok(()) + } + + /// Internal implementation of [`lower`](Self::lower) + fn lower_inner( + &mut self, + stmt: &StmtContext, + frontend: &mut Frontend, + expr: Handle<HirExpr>, + pos: ExprPos, + ) -> Result<(Option<Handle<Expression>>, Span)> { + let HirExpr { ref kind, meta } = stmt.hir_exprs[expr]; + + log::debug!("Lowering {:?} (kind {:?}, pos {:?})", expr, kind, pos); + + let handle = match *kind { + HirExprKind::Access { base, index } => { + let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?; + let maybe_constant_index = match pos { + // Don't try to generate `AccessIndex` if in a LHS position, since it + // wouldn't produce a pointer. + ExprPos::Lhs => None, + _ => self + .module + .to_ctx() + .eval_expr_to_u32_from(index, &self.expressions) + .ok(), + }; + + let base = self + .lower_expect_inner( + stmt, + frontend, + base, + pos.maybe_access_base(maybe_constant_index.is_some()), + )? + .0; + + let pointer = maybe_constant_index + .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta)) + .unwrap_or_else(|| { + self.add_expression(Expression::Access { base, index }, meta) + })?; + + if ExprPos::Rhs == pos { + let resolved = self.resolve_type(pointer, meta)?; + if resolved.pointer_space().is_some() { + return Ok(( + Some(self.add_expression(Expression::Load { pointer }, meta)?), + meta, + )); + } + } + + pointer + } + HirExprKind::Select { base, ref field } => { + let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0; + + frontend.field_selection(self, pos, base, field, meta)? + } + HirExprKind::Literal(literal) if pos != ExprPos::Lhs => { + self.add_expression(Expression::Literal(literal), meta)? + } + HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => { + let (mut left, left_meta) = + self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?; + let (mut right, right_meta) = + self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?; + + match op { + BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => { + self.implicit_conversion(&mut right, right_meta, Scalar::U32)? + } + _ => self + .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?, + } + + self.typifier_grow(left, left_meta)?; + self.typifier_grow(right, right_meta)?; + + let left_inner = self.get_type(left); + let right_inner = self.get_type(right); + + match (left_inner, right_inner) { + ( + &TypeInner::Matrix { + columns: left_columns, + rows: left_rows, + scalar: left_scalar, + }, + &TypeInner::Matrix { + columns: right_columns, + rows: right_rows, + scalar: right_scalar, + }, + ) => { + let dimensions_ok = if op == BinaryOperator::Multiply { + left_columns == right_rows + } else { + left_columns == right_columns && left_rows == right_rows + }; + + // Check that the two arguments have the same dimensions + if !dimensions_ok || left_scalar != right_scalar { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {left_inner:?} and {right_inner:?}" + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Divide => { + // Naga IR doesn't support matrix division so we need to + // divide the columns individually and reassemble the matrix + let mut components = Vec::with_capacity(left_columns as usize); + + for index in 0..left_columns as u32 { + // Get the column vectors + let left_vector = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + )?; + let right_vector = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + )?; + + // Divide the vectors + let column = self.add_expression( + Expression::Binary { + op, + left: left_vector, + right: right_vector, + }, + meta, + )?; + + components.push(column) + } + + let ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns: left_columns, + rows: left_rows, + scalar: left_scalar, + }, + }, + Span::default(), + ); + + // Rebuild the matrix from the divided vectors + self.add_expression(Expression::Compose { ty, components }, meta)? + } + BinaryOperator::Equal | BinaryOperator::NotEqual => { + // Naga IR doesn't support matrix comparisons so we need to + // compare the columns individually and then fold them together + // + // The folding is done using a logical and for equality and + // a logical or for inequality + let equals = op == BinaryOperator::Equal; + + let (op, combine, fun) = match equals { + true => ( + BinaryOperator::Equal, + BinaryOperator::LogicalAnd, + RelationalFunction::All, + ), + false => ( + BinaryOperator::NotEqual, + BinaryOperator::LogicalOr, + RelationalFunction::Any, + ), + }; + + let mut root = None; + + for index in 0..left_columns as u32 { + // Get the column vectors + let left_vector = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + )?; + let right_vector = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + )?; + + let argument = self.add_expression( + Expression::Binary { + op, + left: left_vector, + right: right_vector, + }, + meta, + )?; + + // The result of comparing two vectors is a boolean vector + // so use a relational function like all to get a single + // boolean value + let compare = self.add_expression( + Expression::Relational { fun, argument }, + meta, + )?; + + // Fold the result + root = Some(match root { + Some(right) => self.add_expression( + Expression::Binary { + op: combine, + left: compare, + right, + }, + meta, + )?, + None => compare, + }); + } + + root.unwrap() + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta)? + } + } + } + (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op { + BinaryOperator::Equal | BinaryOperator::NotEqual => { + let equals = op == BinaryOperator::Equal; + + let (op, fun) = match equals { + true => (BinaryOperator::Equal, RelationalFunction::All), + false => (BinaryOperator::NotEqual, RelationalFunction::Any), + }; + + let argument = + self.add_expression(Expression::Binary { op, left, right }, meta)?; + + self.add_expression(Expression::Relational { fun, argument }, meta)? + } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, + }, + (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op { + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Divide + | BinaryOperator::And + | BinaryOperator::ExclusiveOr + | BinaryOperator::InclusiveOr + | BinaryOperator::ShiftLeft + | BinaryOperator::ShiftRight => { + let scalar_vector = self + .add_expression(Expression::Splat { size, value: right }, meta)?; + + self.add_expression( + Expression::Binary { + op, + left, + right: scalar_vector, + }, + meta, + )? + } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, + }, + (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op { + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Divide + | BinaryOperator::And + | BinaryOperator::ExclusiveOr + | BinaryOperator::InclusiveOr => { + let scalar_vector = + self.add_expression(Expression::Splat { size, value: left }, meta)?; + + self.add_expression( + Expression::Binary { + op, + left: scalar_vector, + right, + }, + meta, + )? + } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, + }, + ( + &TypeInner::Scalar(left_scalar), + &TypeInner::Matrix { + rows, + columns, + scalar: right_scalar, + }, + ) => { + // Check that the two arguments have the same scalar type + if left_scalar != right_scalar { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {left_inner:?} and {right_inner:?}" + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Divide + | BinaryOperator::Add + | BinaryOperator::Subtract => { + // Naga IR doesn't support all matrix by scalar operations so + // we need for some to turn the scalar into a vector by + // splatting it and then for each column vector apply the + // operation and finally reconstruct the matrix + let scalar_vector = self.add_expression( + Expression::Splat { + size: rows, + value: left, + }, + meta, + )?; + + let mut components = Vec::with_capacity(columns as usize); + + for index in 0..columns as u32 { + // Get the column vector + let matrix_column = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + )?; + + // Apply the operation to the splatted vector and + // the column vector + let column = self.add_expression( + Expression::Binary { + op, + left: scalar_vector, + right: matrix_column, + }, + meta, + )?; + + components.push(column) + } + + let ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + scalar: left_scalar, + }, + }, + Span::default(), + ); + + // Rebuild the matrix from the operation result vectors + self.add_expression(Expression::Compose { ty, components }, meta)? + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta)? + } + } + } + ( + &TypeInner::Matrix { + rows, + columns, + scalar: left_scalar, + }, + &TypeInner::Scalar(right_scalar), + ) => { + // Check that the two arguments have the same scalar type + if left_scalar != right_scalar { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {left_inner:?} and {right_inner:?}" + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Divide + | BinaryOperator::Add + | BinaryOperator::Subtract => { + // Naga IR doesn't support all matrix by scalar operations so + // we need for some to turn the scalar into a vector by + // splatting it and then for each column vector apply the + // operation and finally reconstruct the matrix + + let scalar_vector = self.add_expression( + Expression::Splat { + size: rows, + value: right, + }, + meta, + )?; + + let mut components = Vec::with_capacity(columns as usize); + + for index in 0..columns as u32 { + // Get the column vector + let matrix_column = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + )?; + + // Apply the operation to the splatted vector and + // the column vector + let column = self.add_expression( + Expression::Binary { + op, + left: matrix_column, + right: scalar_vector, + }, + meta, + )?; + + components.push(column) + } + + let ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + scalar: left_scalar, + }, + }, + Span::default(), + ); + + // Rebuild the matrix from the operation result vectors + self.add_expression(Expression::Compose { ty, components }, meta)? + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta)? + } + } + } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, + } + } + HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => { + let expr = self + .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)? + .0; + + self.add_expression(Expression::Unary { op, expr }, meta)? + } + HirExprKind::Variable(ref var) => match pos { + ExprPos::Lhs => { + if !var.mutable { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Variable cannot be used in LHS position".into(), + ), + meta, + }) + } + + var.expr + } + ExprPos::AccessBase { constant_index } => { + // If the index isn't constant all accesses backed by a constant base need + // to be done through a proxy local variable, since constants have a non + // pointer type which is required for dynamic indexing + if !constant_index { + if let Some((constant, ty)) = var.constant { + let init = self + .add_expression(Expression::Constant(constant), Span::default())?; + let local = self.locals.append( + LocalVariable { + name: None, + ty, + init: Some(init), + }, + Span::default(), + ); + + self.add_expression(Expression::LocalVariable(local), Span::default())? + } else { + var.expr + } + } else { + var.expr + } + } + _ if var.load => { + self.add_expression(Expression::Load { pointer: var.expr }, meta)? + } + ExprPos::Rhs => { + if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() { + self.add_expression(Expression::Constant(constant), meta)? + } else { + var.expr + } + } + }, + HirExprKind::Call(ref call) if pos != ExprPos::Lhs => { + let maybe_expr = frontend.function_or_constructor_call( + self, + stmt, + call.kind.clone(), + &call.args, + meta, + )?; + return Ok((maybe_expr, meta)); + } + // `HirExprKind::Conditional` represents the ternary operator in glsl (`:?`) + // + // The ternary operator is defined to only evaluate one of the two possible + // expressions which means that it's behavior is that of an `if` statement, + // and it's merely syntactic sugar for it. + HirExprKind::Conditional { + condition, + accept, + reject, + } if ExprPos::Lhs != pos => { + // Given an expression `a ? b : c`, we need to produce a Naga + // statement roughly like: + // + // var temp; + // if a { + // temp = convert(b); + // } else { + // temp = convert(c); + // } + // + // where `convert` stands for type conversions to bring `b` and `c` to + // the same type, and then use `temp` to represent the value of the whole + // conditional expression in subsequent code. + + // Lower the condition first to the current bodyy + let condition = self + .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)? + .0; + + let (mut accept_body, (mut accept, accept_meta)) = + self.new_body_with_ret(|ctx| { + // Lower the `true` branch + ctx.lower_expect_inner(stmt, frontend, accept, pos) + })?; + + let (mut reject_body, (mut reject, reject_meta)) = + self.new_body_with_ret(|ctx| { + // Lower the `false` branch + ctx.lower_expect_inner(stmt, frontend, reject, pos) + })?; + + // We need to do some custom implicit conversions since the two target expressions + // are in different bodies + if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = ( + // Get the components of both branches and calculate the type power + self.expr_scalar_components(accept, accept_meta)? + .and_then(|scalar| Some((type_power(scalar)?, scalar))), + self.expr_scalar_components(reject, reject_meta)? + .and_then(|scalar| Some((type_power(scalar)?, scalar))), + ) { + match accept_power.cmp(&reject_power) { + std::cmp::Ordering::Less => { + accept_body = self.with_body(accept_body, |ctx| { + ctx.conversion(&mut accept, accept_meta, reject_scalar)?; + Ok(()) + })?; + } + std::cmp::Ordering::Equal => {} + std::cmp::Ordering::Greater => { + reject_body = self.with_body(reject_body, |ctx| { + ctx.conversion(&mut reject, reject_meta, accept_scalar)?; + Ok(()) + })?; + } + } + } + + // We need to get the type of the resulting expression to create the local, + // this must be done after implicit conversions to ensure both branches have + // the same type. + let ty = self.resolve_type_handle(accept, accept_meta)?; + + // Add the local that will hold the result of our conditional + let local = self.locals.append( + LocalVariable { + name: None, + ty, + init: None, + }, + meta, + ); + + let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?; + + // Add to each the store to the result variable + accept_body.push( + Statement::Store { + pointer: local_expr, + value: accept, + }, + accept_meta, + ); + reject_body.push( + Statement::Store { + pointer: local_expr, + value: reject, + }, + reject_meta, + ); + + // Finally add the `If` to the main body with the `condition` we lowered + // earlier and the branches we prepared. + self.body.push( + Statement::If { + condition, + accept: accept_body, + reject: reject_body, + }, + meta, + ); + + // Note: `Expression::Load` must be emitted before it's used so make + // sure the emitter is active here. + self.add_expression( + Expression::Load { + pointer: local_expr, + }, + meta, + )? + } + HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => { + let (pointer, ptr_meta) = + self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?; + let (mut value, value_meta) = + self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?; + + let ty = match *self.resolve_type(pointer, ptr_meta)? { + TypeInner::Pointer { base, .. } => &self.module.types[base].inner, + ref ty => ty, + }; + + if let Some(scalar) = scalar_components(ty) { + self.implicit_conversion(&mut value, value_meta, scalar)?; + } + + self.lower_store(pointer, value, meta)?; + + value + } + HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => { + let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?; + let left = if let Expression::Swizzle { .. } = self.expressions[pointer] { + pointer + } else { + self.add_expression(Expression::Load { pointer }, meta)? + }; + + let res = match *self.resolve_type(left, meta)? { + TypeInner::Scalar(scalar) => { + let ty = TypeInner::Scalar(scalar); + Literal::one(scalar).map(|i| (ty, i, None, None)) + } + TypeInner::Vector { size, scalar } => { + let ty = TypeInner::Vector { size, scalar }; + Literal::one(scalar).map(|i| (ty, i, Some(size), None)) + } + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + let ty = TypeInner::Matrix { + columns, + rows, + scalar, + }; + Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns))) + } + _ => None, + }; + let (ty_inner, literal, rows, columns) = match res { + Some(res) => res, + None => { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Increment/decrement only works on scalar/vector/matrix".into(), + ), + meta, + }); + return Ok((Some(left), meta)); + } + }; + + let mut right = self.add_expression(Expression::Literal(literal), meta)?; + + // Glsl allows pre/postfixes operations on vectors and matrices, so if the + // target is either of them change the right side of the addition to be splatted + // to the same size as the target, furthermore if the target is a matrix + // use a composed matrix using the splatted value. + if let Some(size) = rows { + right = self.add_expression(Expression::Splat { size, value: right }, meta)?; + + if let Some(cols) = columns { + let ty = self.module.types.insert( + Type { + name: None, + inner: ty_inner, + }, + meta, + ); + + right = self.add_expression( + Expression::Compose { + ty, + components: std::iter::repeat(right).take(cols as usize).collect(), + }, + meta, + )?; + } + } + + let value = self.add_expression(Expression::Binary { op, left, right }, meta)?; + + self.lower_store(pointer, value, meta)?; + + if postfix { + left + } else { + value + } + } + HirExprKind::Method { + expr: object, + ref name, + ref args, + } if ExprPos::Lhs != pos => { + let args = args + .iter() + .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs)) + .collect::<Result<Vec<_>>>()?; + match name.as_ref() { + "length" => { + if !args.is_empty() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + ".length() doesn't take any arguments".into(), + ), + meta, + }); + } + let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0; + let array_type = self.resolve_type(lowered_array, meta)?; + + match *array_type { + TypeInner::Array { + size: crate::ArraySize::Constant(size), + .. + } => { + let mut array_length = self.add_expression( + Expression::Literal(Literal::U32(size.get())), + meta, + )?; + self.forced_conversion(&mut array_length, meta, Scalar::I32)?; + array_length + } + // let the error be handled in type checking if it's not a dynamic array + _ => { + let mut array_length = self + .add_expression(Expression::ArrayLength(lowered_array), meta)?; + self.conversion(&mut array_length, meta, Scalar::I32)?; + array_length + } + } + } + _ => { + return Err(Error { + kind: ErrorKind::SemanticError( + format!("unknown method '{name}'").into(), + ), + meta, + }); + } + } + } + _ => { + return Err(Error { + kind: ErrorKind::SemanticError( + format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr]) + .into(), + ), + meta, + }) + } + }; + + log::trace!( + "Lowered {:?}\n\tKind = {:?}\n\tPos = {:?}\n\tResult = {:?}", + expr, + kind, + pos, + handle + ); + + Ok((Some(handle), meta)) + } + + pub fn expr_scalar_components( + &mut self, + expr: Handle<Expression>, + meta: Span, + ) -> Result<Option<Scalar>> { + let ty = self.resolve_type(expr, meta)?; + Ok(scalar_components(ty)) + } + + pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> { + Ok(self + .expr_scalar_components(expr, meta)? + .and_then(type_power)) + } + + pub fn conversion( + &mut self, + expr: &mut Handle<Expression>, + meta: Span, + scalar: Scalar, + ) -> Result<()> { + *expr = self.add_expression( + Expression::As { + expr: *expr, + kind: scalar.kind, + convert: Some(scalar.width), + }, + meta, + )?; + + Ok(()) + } + + pub fn implicit_conversion( + &mut self, + expr: &mut Handle<Expression>, + meta: Span, + scalar: Scalar, + ) -> Result<()> { + if let (Some(tgt_power), Some(expr_power)) = + (type_power(scalar), self.expr_power(*expr, meta)?) + { + if tgt_power > expr_power { + self.conversion(expr, meta, scalar)?; + } + } + + Ok(()) + } + + pub fn forced_conversion( + &mut self, + expr: &mut Handle<Expression>, + meta: Span, + scalar: Scalar, + ) -> Result<()> { + if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? { + if expr_scalar != scalar { + self.conversion(expr, meta, scalar)?; + } + } + + Ok(()) + } + + pub fn binary_implicit_conversion( + &mut self, + left: &mut Handle<Expression>, + left_meta: Span, + right: &mut Handle<Expression>, + right_meta: Span, + ) -> Result<()> { + let left_components = self.expr_scalar_components(*left, left_meta)?; + let right_components = self.expr_scalar_components(*right, right_meta)?; + + if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = ( + left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))), + right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))), + ) { + match left_power.cmp(&right_power) { + std::cmp::Ordering::Less => { + self.conversion(left, left_meta, right_scalar)?; + } + std::cmp::Ordering::Equal => {} + std::cmp::Ordering::Greater => { + self.conversion(right, right_meta, left_scalar)?; + } + } + } + + Ok(()) + } + + pub fn implicit_splat( + &mut self, + expr: &mut Handle<Expression>, + meta: Span, + vector_size: Option<VectorSize>, + ) -> Result<()> { + let expr_type = self.resolve_type(*expr, meta)?; + + if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) { + *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)? + } + + Ok(()) + } + + pub fn vector_resize( + &mut self, + size: VectorSize, + vector: Handle<Expression>, + meta: Span, + ) -> Result<Handle<Expression>> { + self.add_expression( + Expression::Swizzle { + size, + vector, + pattern: crate::SwizzleComponent::XYZW, + }, + meta, + ) + } +} + +impl Index<Handle<Expression>> for Context<'_> { + type Output = Expression; + + fn index(&self, index: Handle<Expression>) -> &Self::Output { + if self.is_const { + &self.module.const_expressions[index] + } else { + &self.expressions[index] + } + } +} + +/// Helper struct passed when parsing expressions +/// +/// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx) +/// and only one of these may be active at any time per context. +#[derive(Debug)] +pub struct StmtContext { + /// A arena of high level expressions which can be lowered through a + /// [`Context`] to Naga's [`Expression`]s + pub hir_exprs: Arena<HirExpr>, +} + +impl StmtContext { + const fn new() -> Self { + StmtContext { + hir_exprs: Arena::new(), + } + } +} diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs new file mode 100644 index 0000000000..bd16ee30bc --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/error.rs @@ -0,0 +1,191 @@ +use super::token::TokenValue; +use crate::{proc::ConstantEvaluatorError, Span}; +use codespan_reporting::diagnostic::{Diagnostic, Label}; +use codespan_reporting::files::SimpleFile; +use codespan_reporting::term; +use pp_rs::token::PreprocessorError; +use std::borrow::Cow; +use termcolor::{NoColor, WriteColor}; +use thiserror::Error; + +fn join_with_comma(list: &[ExpectedToken]) -> String { + let mut string = "".to_string(); + for (i, val) in list.iter().enumerate() { + string.push_str(&val.to_string()); + match i { + i if i == list.len() - 1 => {} + i if i == list.len() - 2 => string.push_str(" or "), + _ => string.push_str(", "), + } + } + string +} + +/// One of the expected tokens returned in [`InvalidToken`](ErrorKind::InvalidToken). +#[derive(Clone, Debug, PartialEq)] +pub enum ExpectedToken { + /// A specific token was expected. + Token(TokenValue), + /// A type was expected. + TypeName, + /// An identifier was expected. + Identifier, + /// An integer literal was expected. + IntLiteral, + /// A float literal was expected. + FloatLiteral, + /// A boolean literal was expected. + BoolLiteral, + /// The end of file was expected. + Eof, +} +impl From<TokenValue> for ExpectedToken { + fn from(token: TokenValue) -> Self { + ExpectedToken::Token(token) + } +} +impl std::fmt::Display for ExpectedToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self { + ExpectedToken::Token(ref token) => write!(f, "{token:?}"), + ExpectedToken::TypeName => write!(f, "a type"), + ExpectedToken::Identifier => write!(f, "identifier"), + ExpectedToken::IntLiteral => write!(f, "integer literal"), + ExpectedToken::FloatLiteral => write!(f, "float literal"), + ExpectedToken::BoolLiteral => write!(f, "bool literal"), + ExpectedToken::Eof => write!(f, "end of file"), + } + } +} + +/// Information about the cause of an error. +#[derive(Clone, Debug, Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ErrorKind { + /// Whilst parsing as encountered an unexpected EOF. + #[error("Unexpected end of file")] + EndOfFile, + /// The shader specified an unsupported or invalid profile. + #[error("Invalid profile: {0}")] + InvalidProfile(String), + /// The shader requested an unsupported or invalid version. + #[error("Invalid version: {0}")] + InvalidVersion(u64), + /// Whilst parsing an unexpected token was encountered. + /// + /// A list of expected tokens is also returned. + #[error("Expected {}, found {0:?}", join_with_comma(.1))] + InvalidToken(TokenValue, Vec<ExpectedToken>), + /// A specific feature is not yet implemented. + /// + /// To help prioritize work please open an issue in the github issue tracker + /// if none exist already or react to the already existing one. + #[error("Not implemented: {0}")] + NotImplemented(&'static str), + /// A reference to a variable that wasn't declared was used. + #[error("Unknown variable: {0}")] + UnknownVariable(String), + /// A reference to a type that wasn't declared was used. + #[error("Unknown type: {0}")] + UnknownType(String), + /// A reference to a non existent member of a type was made. + #[error("Unknown field: {0}")] + UnknownField(String), + /// An unknown layout qualifier was used. + /// + /// If the qualifier does exist please open an issue in the github issue tracker + /// if none exist already or react to the already existing one to help + /// prioritize work. + #[error("Unknown layout qualifier: {0}")] + UnknownLayoutQualifier(String), + /// Unsupported matrix of the form matCx2 + /// + /// Our IR expects matrices of the form matCx2 to have a stride of 8 however + /// matrices in the std140 layout have a stride of at least 16 + #[error("unsupported matrix of the form matCx2 in std140 block layout")] + UnsupportedMatrixTypeInStd140, + /// A variable with the same name already exists in the current scope. + #[error("Variable already declared: {0}")] + VariableAlreadyDeclared(String), + /// A semantic error was detected in the shader. + #[error("{0}")] + SemanticError(Cow<'static, str>), + /// An error was returned by the preprocessor. + #[error("{0:?}")] + PreprocessorError(PreprocessorError), + /// The parser entered an illegal state and exited + /// + /// This obviously is a bug and as such should be reported in the github issue tracker + #[error("Internal error: {0}")] + InternalError(&'static str), +} + +impl From<ConstantEvaluatorError> for ErrorKind { + fn from(err: ConstantEvaluatorError) -> Self { + ErrorKind::SemanticError(err.to_string().into()) + } +} + +/// Error returned during shader parsing. +#[derive(Clone, Debug, Error)] +#[error("{kind}")] +#[cfg_attr(test, derive(PartialEq))] +pub struct Error { + /// Holds the information about the error itself. + pub kind: ErrorKind, + /// Holds information about the range of the source code where the error happened. + pub meta: Span, +} + +/// A collection of errors returned during shader parsing. +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub struct ParseError { + pub errors: Vec<Error>, +} + +impl ParseError { + pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) { + self.emit_to_writer_with_path(writer, source, "glsl"); + } + + pub fn emit_to_writer_with_path(&self, writer: &mut impl WriteColor, source: &str, path: &str) { + let path = path.to_string(); + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + + for err in &self.errors { + let mut diagnostic = Diagnostic::error().with_message(err.kind.to_string()); + + if let Some(range) = err.meta.to_range() { + diagnostic = diagnostic.with_labels(vec![Label::primary((), range)]); + } + + term::emit(writer, &config, &files, &diagnostic).expect("cannot write error"); + } + } + + pub fn emit_to_string(&self, source: &str) -> String { + let mut writer = NoColor::new(Vec::new()); + self.emit_to_writer(&mut writer, source); + String::from_utf8(writer.into_inner()).unwrap() + } +} + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.errors.iter().try_for_each(|e| write!(f, "{e:?}")) + } +} + +impl std::error::Error for ParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +impl From<Vec<Error>> for ParseError { + fn from(errors: Vec<Error>) -> Self { + Self { errors } + } +} diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs new file mode 100644 index 0000000000..df8cc8a30e --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/functions.rs @@ -0,0 +1,1602 @@ +use super::{ + ast::*, + builtins::{inject_builtin, sampled_to_depth}, + context::{Context, ExprPos, StmtContext}, + error::{Error, ErrorKind}, + types::scalar_components, + Frontend, Result, +}; +use crate::{ + front::glsl::types::type_power, proc::ensure_block_returns, AddressSpace, Block, EntryPoint, + Expression, Function, FunctionArgument, FunctionResult, Handle, Literal, LocalVariable, Scalar, + ScalarKind, Span, Statement, StructMember, Type, TypeInner, +}; +use std::iter; + +/// Struct detailing a store operation that must happen after a function call +struct ProxyWrite { + /// The store target + target: Handle<Expression>, + /// A pointer to read the value of the store + value: Handle<Expression>, + /// An optional conversion to be applied + convert: Option<Scalar>, +} + +impl Frontend { + pub(crate) fn function_or_constructor_call( + &mut self, + ctx: &mut Context, + stmt: &StmtContext, + fc: FunctionCallKind, + raw_args: &[Handle<HirExpr>], + meta: Span, + ) -> Result<Option<Handle<Expression>>> { + let args: Vec<_> = raw_args + .iter() + .map(|e| ctx.lower_expect_inner(stmt, self, *e, ExprPos::Rhs)) + .collect::<Result<_>>()?; + + match fc { + FunctionCallKind::TypeConstructor(ty) => { + if args.len() == 1 { + self.constructor_single(ctx, ty, args[0], meta).map(Some) + } else { + self.constructor_many(ctx, ty, args, meta).map(Some) + } + } + FunctionCallKind::Function(name) => { + self.function_call(ctx, stmt, name, args, raw_args, meta) + } + } + } + + fn constructor_single( + &mut self, + ctx: &mut Context, + ty: Handle<Type>, + (mut value, expr_meta): (Handle<Expression>, Span), + meta: Span, + ) -> Result<Handle<Expression>> { + let expr_type = ctx.resolve_type(value, expr_meta)?; + + let vector_size = match *expr_type { + TypeInner::Vector { size, .. } => Some(size), + _ => None, + }; + + let expr_is_bool = expr_type.scalar_kind() == Some(ScalarKind::Bool); + + // Special case: if casting from a bool, we need to use Select and not As. + match ctx.module.types[ty].inner.scalar() { + Some(result_scalar) if expr_is_bool && result_scalar.kind != ScalarKind::Bool => { + let result_scalar = Scalar { + width: 4, + ..result_scalar + }; + let l0 = Literal::zero(result_scalar).unwrap(); + let l1 = Literal::one(result_scalar).unwrap(); + let mut reject = ctx.add_expression(Expression::Literal(l0), expr_meta)?; + let mut accept = ctx.add_expression(Expression::Literal(l1), expr_meta)?; + + ctx.implicit_splat(&mut reject, meta, vector_size)?; + ctx.implicit_splat(&mut accept, meta, vector_size)?; + + let h = ctx.add_expression( + Expression::Select { + accept, + reject, + condition: value, + }, + expr_meta, + )?; + + return Ok(h); + } + _ => {} + } + + Ok(match ctx.module.types[ty].inner { + TypeInner::Vector { size, scalar } if vector_size.is_none() => { + ctx.forced_conversion(&mut value, expr_meta, scalar)?; + + if let TypeInner::Scalar { .. } = *ctx.resolve_type(value, expr_meta)? { + ctx.add_expression(Expression::Splat { size, value }, meta)? + } else { + self.vector_constructor(ctx, ty, size, scalar, &[(value, expr_meta)], meta)? + } + } + TypeInner::Scalar(scalar) => { + let mut expr = value; + if let TypeInner::Vector { .. } | TypeInner::Matrix { .. } = + *ctx.resolve_type(value, expr_meta)? + { + expr = ctx.add_expression( + Expression::AccessIndex { + base: expr, + index: 0, + }, + meta, + )?; + } + + if let TypeInner::Matrix { .. } = *ctx.resolve_type(value, expr_meta)? { + expr = ctx.add_expression( + Expression::AccessIndex { + base: expr, + index: 0, + }, + meta, + )?; + } + + ctx.add_expression( + Expression::As { + kind: scalar.kind, + expr, + convert: Some(scalar.width), + }, + meta, + )? + } + TypeInner::Vector { size, scalar } => { + if vector_size.map_or(true, |s| s != size) { + value = ctx.vector_resize(size, value, expr_meta)?; + } + + ctx.add_expression( + Expression::As { + kind: scalar.kind, + expr: value, + convert: Some(scalar.width), + }, + meta, + )? + } + TypeInner::Matrix { + columns, + rows, + scalar, + } => self.matrix_one_arg(ctx, ty, columns, rows, scalar, (value, expr_meta), meta)?, + TypeInner::Struct { ref members, .. } => { + let scalar_components = members + .get(0) + .and_then(|member| scalar_components(&ctx.module.types[member.ty].inner)); + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut value, expr_meta, scalar)?; + } + + ctx.add_expression( + Expression::Compose { + ty, + components: vec![value], + }, + meta, + )? + } + + TypeInner::Array { base, .. } => { + let scalar_components = scalar_components(&ctx.module.types[base].inner); + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut value, expr_meta, scalar)?; + } + + ctx.add_expression( + Expression::Compose { + ty, + components: vec![value], + }, + meta, + )? + } + _ => { + self.errors.push(Error { + kind: ErrorKind::SemanticError("Bad type constructor".into()), + meta, + }); + + value + } + }) + } + + #[allow(clippy::too_many_arguments)] + fn matrix_one_arg( + &mut self, + ctx: &mut Context, + ty: Handle<Type>, + columns: crate::VectorSize, + rows: crate::VectorSize, + element_scalar: Scalar, + (mut value, expr_meta): (Handle<Expression>, Span), + meta: Span, + ) -> Result<Handle<Expression>> { + let mut components = Vec::with_capacity(columns as usize); + // TODO: casts + // `Expression::As` doesn't support matrix width + // casts so we need to do some extra work for casts + + ctx.forced_conversion(&mut value, expr_meta, element_scalar)?; + match *ctx.resolve_type(value, expr_meta)? { + TypeInner::Scalar(_) => { + // If a matrix is constructed with a single scalar value, then that + // value is used to initialize all the values along the diagonal of + // the matrix; the rest are given zeros. + let vector_ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: rows, + scalar: element_scalar, + }, + }, + meta, + ); + + let zero_literal = Literal::zero(element_scalar).unwrap(); + let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?; + + for i in 0..columns as u32 { + components.push( + ctx.add_expression( + Expression::Compose { + ty: vector_ty, + components: (0..rows as u32) + .map(|r| match r == i { + true => value, + false => zero, + }) + .collect(), + }, + meta, + )?, + ) + } + } + TypeInner::Matrix { + rows: ori_rows, + columns: ori_cols, + .. + } => { + // If a matrix is constructed from a matrix, then each component + // (column i, row j) in the result that has a corresponding component + // (column i, row j) in the argument will be initialized from there. All + // other components will be initialized to the identity matrix. + + let zero_literal = Literal::zero(element_scalar).unwrap(); + let one_literal = Literal::one(element_scalar).unwrap(); + + let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?; + let one = ctx.add_expression(Expression::Literal(one_literal), meta)?; + + let vector_ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: rows, + scalar: element_scalar, + }, + }, + meta, + ); + + for i in 0..columns as u32 { + if i < ori_cols as u32 { + use std::cmp::Ordering; + + let vector = ctx.add_expression( + Expression::AccessIndex { + base: value, + index: i, + }, + meta, + )?; + + components.push(match ori_rows.cmp(&rows) { + Ordering::Less => { + let components = (0..rows as u32) + .map(|r| { + if r < ori_rows as u32 { + ctx.add_expression( + Expression::AccessIndex { + base: vector, + index: r, + }, + meta, + ) + } else if r == i { + Ok(one) + } else { + Ok(zero) + } + }) + .collect::<Result<_>>()?; + + ctx.add_expression( + Expression::Compose { + ty: vector_ty, + components, + }, + meta, + )? + } + Ordering::Equal => vector, + Ordering::Greater => ctx.vector_resize(rows, vector, meta)?, + }) + } else { + let compose_expr = Expression::Compose { + ty: vector_ty, + components: (0..rows as u32) + .map(|r| match r == i { + true => one, + false => zero, + }) + .collect(), + }; + + let vec = ctx.add_expression(compose_expr, meta)?; + + components.push(vec) + } + } + } + _ => { + components = iter::repeat(value).take(columns as usize).collect(); + } + } + + ctx.add_expression(Expression::Compose { ty, components }, meta) + } + + #[allow(clippy::too_many_arguments)] + fn vector_constructor( + &mut self, + ctx: &mut Context, + ty: Handle<Type>, + size: crate::VectorSize, + scalar: Scalar, + args: &[(Handle<Expression>, Span)], + meta: Span, + ) -> Result<Handle<Expression>> { + let mut components = Vec::with_capacity(size as usize); + + for (mut arg, expr_meta) in args.iter().copied() { + ctx.forced_conversion(&mut arg, expr_meta, scalar)?; + + if components.len() >= size as usize { + break; + } + + match *ctx.resolve_type(arg, expr_meta)? { + TypeInner::Scalar { .. } => components.push(arg), + TypeInner::Matrix { rows, columns, .. } => { + components.reserve(rows as usize * columns as usize); + for c in 0..(columns as u32) { + let base = ctx.add_expression( + Expression::AccessIndex { + base: arg, + index: c, + }, + expr_meta, + )?; + for r in 0..(rows as u32) { + components.push(ctx.add_expression( + Expression::AccessIndex { base, index: r }, + expr_meta, + )?) + } + } + } + TypeInner::Vector { size: ori_size, .. } => { + components.reserve(ori_size as usize); + for index in 0..(ori_size as u32) { + components.push(ctx.add_expression( + Expression::AccessIndex { base: arg, index }, + expr_meta, + )?) + } + } + _ => components.push(arg), + } + } + + components.truncate(size as usize); + + ctx.add_expression(Expression::Compose { ty, components }, meta) + } + + fn constructor_many( + &mut self, + ctx: &mut Context, + ty: Handle<Type>, + args: Vec<(Handle<Expression>, Span)>, + meta: Span, + ) -> Result<Handle<Expression>> { + let mut components = Vec::with_capacity(args.len()); + + let struct_member_data = match ctx.module.types[ty].inner { + TypeInner::Matrix { + columns, + rows, + scalar: element_scalar, + } => { + let mut flattened = Vec::with_capacity(columns as usize * rows as usize); + + for (mut arg, meta) in args.iter().copied() { + ctx.forced_conversion(&mut arg, meta, element_scalar)?; + + match *ctx.resolve_type(arg, meta)? { + TypeInner::Vector { size, .. } => { + for i in 0..(size as u32) { + flattened.push(ctx.add_expression( + Expression::AccessIndex { + base: arg, + index: i, + }, + meta, + )?) + } + } + _ => flattened.push(arg), + } + } + + let ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: rows, + scalar: element_scalar, + }, + }, + meta, + ); + + for chunk in flattened.chunks(rows as usize) { + components.push(ctx.add_expression( + Expression::Compose { + ty, + components: Vec::from(chunk), + }, + meta, + )?) + } + None + } + TypeInner::Vector { size, scalar } => { + return self.vector_constructor(ctx, ty, size, scalar, &args, meta) + } + TypeInner::Array { base, .. } => { + for (mut arg, meta) in args.iter().copied() { + let scalar_components = scalar_components(&ctx.module.types[base].inner); + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut arg, meta, scalar)?; + } + + components.push(arg) + } + None + } + TypeInner::Struct { ref members, .. } => Some( + members + .iter() + .map(|member| scalar_components(&ctx.module.types[member.ty].inner)) + .collect::<Vec<_>>(), + ), + _ => { + return Err(Error { + kind: ErrorKind::SemanticError("Constructor: Too many arguments".into()), + meta, + }) + } + }; + + if let Some(struct_member_data) = struct_member_data { + for ((mut arg, meta), scalar_components) in + args.iter().copied().zip(struct_member_data.iter().copied()) + { + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut arg, meta, scalar)?; + } + + components.push(arg) + } + } + + ctx.add_expression(Expression::Compose { ty, components }, meta) + } + + #[allow(clippy::too_many_arguments)] + fn function_call( + &mut self, + ctx: &mut Context, + stmt: &StmtContext, + name: String, + args: Vec<(Handle<Expression>, Span)>, + raw_args: &[Handle<HirExpr>], + meta: Span, + ) -> Result<Option<Handle<Expression>>> { + // Grow the typifier to be able to index it later without needing + // to hold the context mutably + for &(expr, span) in args.iter() { + ctx.typifier_grow(expr, span)?; + } + + // Check if the passed arguments require any special variations + let mut variations = + builtin_required_variations(args.iter().map(|&(expr, _)| ctx.get_type(expr))); + + // Initiate the declaration if it wasn't previously initialized and inject builtins + let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { + variations |= BuiltinVariations::STANDARD; + Default::default() + }); + inject_builtin(declaration, ctx.module, &name, variations); + + // Borrow again but without mutability, at this point a declaration is guaranteed + let declaration = self.lookup_function.get(&name).unwrap(); + + // Possibly contains the overload to be used in the call + let mut maybe_overload = None; + // The conversions needed for the best analyzed overload, this is initialized all to + // `NONE` to make sure that conversions always pass the first time without ambiguity + let mut old_conversions = vec![Conversion::None; args.len()]; + // Tracks whether the comparison between overloads lead to an ambiguity + let mut ambiguous = false; + + // Iterate over all the available overloads to select either an exact match or a + // overload which has suitable implicit conversions + 'outer: for (overload_idx, overload) in declaration.overloads.iter().enumerate() { + // If the overload and the function call don't have the same number of arguments + // continue to the next overload + if args.len() != overload.parameters.len() { + continue; + } + + log::trace!("Testing overload {}", overload_idx); + + // Stores whether the current overload matches exactly the function call + let mut exact = true; + // State of the selection + // If None we still don't know what is the best overload + // If Some(true) the new overload is better + // If Some(false) the old overload is better + let mut superior = None; + // Store the conversions for the current overload so that later they can replace the + // conversions used for querying the best overload + let mut new_conversions = vec![Conversion::None; args.len()]; + + // Loop through the overload parameters and check if the current overload is better + // compared to the previous best overload. + for (i, overload_parameter) in overload.parameters.iter().enumerate() { + let call_argument = &args[i]; + let parameter_info = &overload.parameters_info[i]; + + // If the image is used in the overload as a depth texture convert it + // before comparing, otherwise exact matches wouldn't be reported + if parameter_info.depth { + sampled_to_depth(ctx, call_argument.0, call_argument.1, &mut self.errors); + ctx.invalidate_expression(call_argument.0, call_argument.1)? + } + + ctx.typifier_grow(call_argument.0, call_argument.1)?; + + let overload_param_ty = &ctx.module.types[*overload_parameter].inner; + let call_arg_ty = ctx.get_type(call_argument.0); + + log::trace!( + "Testing parameter {}\n\tOverload = {:?}\n\tCall = {:?}", + i, + overload_param_ty, + call_arg_ty + ); + + // Storage images cannot be directly compared since while the access is part of the + // type in naga's IR, in glsl they are a qualifier and don't enter in the match as + // long as the access needed is satisfied. + if let ( + &TypeInner::Image { + class: + crate::ImageClass::Storage { + format: overload_format, + access: overload_access, + }, + dim: overload_dim, + arrayed: overload_arrayed, + }, + &TypeInner::Image { + class: + crate::ImageClass::Storage { + format: call_format, + access: call_access, + }, + dim: call_dim, + arrayed: call_arrayed, + }, + ) = (overload_param_ty, call_arg_ty) + { + // Images size must match otherwise the overload isn't what we want + let good_size = call_dim == overload_dim && call_arrayed == overload_arrayed; + // Glsl requires the formats to strictly match unless you are builtin + // function overload and have not been replaced, in which case we only + // check that the format scalar kind matches + let good_format = overload_format == call_format + || (overload.internal + && ScalarKind::from(overload_format) == ScalarKind::from(call_format)); + if !(good_size && good_format) { + continue 'outer; + } + + // While storage access mismatch is an error it isn't one that causes + // the overload matching to fail so we defer the error and consider + // that the images match exactly + if !call_access.contains(overload_access) { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "'{name}': image needs {overload_access:?} access but only {call_access:?} was provided" + ) + .into(), + ), + meta, + }); + } + + // The images satisfy the conditions to be considered as an exact match + new_conversions[i] = Conversion::Exact; + continue; + } else if overload_param_ty == call_arg_ty { + // If the types match there's no need to check for conversions so continue + new_conversions[i] = Conversion::Exact; + continue; + } + + // Glsl defines that inout follows both the conversions for input parameters and + // output parameters, this means that the type must have a conversion from both the + // call argument to the function parameter and the function parameter to the call + // argument, the only way this is possible is for the conversion to be an identity + // (i.e. call argument = function parameter) + if let ParameterQualifier::InOut = parameter_info.qualifier { + continue 'outer; + } + + // The function call argument and the function definition + // parameter are not equal at this point, so we need to try + // implicit conversions. + // + // Now there are two cases, the argument is defined as a normal + // parameter (`in` or `const`), in this case an implicit + // conversion is made from the calling argument to the + // definition argument. If the parameter is `out` the + // opposite needs to be done, so the implicit conversion is made + // from the definition argument to the calling argument. + let maybe_conversion = if parameter_info.qualifier.is_lhs() { + conversion(call_arg_ty, overload_param_ty) + } else { + conversion(overload_param_ty, call_arg_ty) + }; + + let conversion = match maybe_conversion { + Some(info) => info, + None => continue 'outer, + }; + + // At this point a conversion will be needed so the overload no longer + // exactly matches the call arguments + exact = false; + + // Compare the conversions needed for this overload parameter to that of the + // last overload analyzed respective parameter, the value is: + // - `true` when the new overload argument has a better conversion + // - `false` when the old overload argument has a better conversion + let best_arg = match (conversion, old_conversions[i]) { + // An exact match is always better, we don't need to check this for the + // current overload since it was checked earlier + (_, Conversion::Exact) => false, + // No overload was yet analyzed so this one is the best yet + (_, Conversion::None) => true, + // A conversion from a float to a double is the best possible conversion + (Conversion::FloatToDouble, _) => true, + (_, Conversion::FloatToDouble) => false, + // A conversion from a float to an integer is preferred than one + // from double to an integer + (Conversion::IntToFloat, Conversion::IntToDouble) => true, + (Conversion::IntToDouble, Conversion::IntToFloat) => false, + // This case handles things like no conversion and exact which were already + // treated and other cases which no conversion is better than the other + _ => continue, + }; + + // Check if the best parameter corresponds to the current selected overload + // to pass to the next comparison, if this isn't true mark it as ambiguous + match best_arg { + true => match superior { + Some(false) => ambiguous = true, + _ => { + superior = Some(true); + new_conversions[i] = conversion + } + }, + false => match superior { + Some(true) => ambiguous = true, + _ => superior = Some(false), + }, + } + } + + // The overload matches exactly the function call so there's no ambiguity (since + // repeated overload aren't allowed) and the current overload is selected, no + // further querying is needed. + if exact { + maybe_overload = Some(overload); + ambiguous = false; + break; + } + + match superior { + // New overload is better keep it + Some(true) => { + maybe_overload = Some(overload); + // Replace the conversions + old_conversions = new_conversions; + } + // Old overload is better do nothing + Some(false) => {} + // No overload was better than the other this can be caused + // when all conversions are ambiguous in which the overloads themselves are + // ambiguous. + None => { + ambiguous = true; + // Assign the new overload, this helps ensures that in this case of + // ambiguity the parsing won't end immediately and allow for further + // collection of errors. + maybe_overload = Some(overload); + } + } + } + + if ambiguous { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + format!("Ambiguous best function for '{name}'").into(), + ), + meta, + }) + } + + let overload = maybe_overload.ok_or_else(|| Error { + kind: ErrorKind::SemanticError(format!("Unknown function '{name}'").into()), + meta, + })?; + + let parameters_info = overload.parameters_info.clone(); + let parameters = overload.parameters.clone(); + let is_void = overload.void; + let kind = overload.kind; + + let mut arguments = Vec::with_capacity(args.len()); + let mut proxy_writes = Vec::new(); + + // Iterate through the function call arguments applying transformations as needed + for (((parameter_info, call_argument), expr), parameter) in parameters_info + .iter() + .zip(&args) + .zip(raw_args) + .zip(¶meters) + { + let (mut handle, meta) = + ctx.lower_expect_inner(stmt, self, *expr, parameter_info.qualifier.as_pos())?; + + if parameter_info.qualifier.is_lhs() { + self.process_lhs_argument( + ctx, + meta, + *parameter, + parameter_info, + handle, + call_argument, + &mut proxy_writes, + &mut arguments, + )?; + + continue; + } + + let scalar_comps = scalar_components(&ctx.module.types[*parameter].inner); + + // Apply implicit conversions as needed + if let Some(scalar) = scalar_comps { + ctx.implicit_conversion(&mut handle, meta, scalar)?; + } + + arguments.push(handle) + } + + match kind { + FunctionKind::Call(function) => { + ctx.emit_end(); + + let result = if !is_void { + Some(ctx.add_expression(Expression::CallResult(function), meta)?) + } else { + None + }; + + ctx.body.push( + crate::Statement::Call { + function, + arguments, + result, + }, + meta, + ); + + ctx.emit_start(); + + // Write back all the variables that were scheduled to their original place + for proxy_write in proxy_writes { + let mut value = ctx.add_expression( + Expression::Load { + pointer: proxy_write.value, + }, + meta, + )?; + + if let Some(scalar) = proxy_write.convert { + ctx.conversion(&mut value, meta, scalar)?; + } + + ctx.emit_restart(); + + ctx.body.push( + Statement::Store { + pointer: proxy_write.target, + value, + }, + meta, + ); + } + + Ok(result) + } + FunctionKind::Macro(builtin) => builtin.call(self, ctx, arguments.as_mut_slice(), meta), + } + } + + /// Processes a function call argument that appears in place of an output + /// parameter. + #[allow(clippy::too_many_arguments)] + fn process_lhs_argument( + &mut self, + ctx: &mut Context, + meta: Span, + parameter_ty: Handle<Type>, + parameter_info: &ParameterInfo, + original: Handle<Expression>, + call_argument: &(Handle<Expression>, Span), + proxy_writes: &mut Vec<ProxyWrite>, + arguments: &mut Vec<Handle<Expression>>, + ) -> Result<()> { + let original_ty = ctx.resolve_type(original, meta)?; + let original_pointer_space = original_ty.pointer_space(); + + // The type of a possible spill variable needed for a proxy write + let mut maybe_ty = match *original_ty { + // If the argument is to be passed as a pointer but the type of the + // expression returns a vector it must mean that it was for example + // swizzled and it must be spilled into a local before calling + TypeInner::Vector { size, scalar } => Some(ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { size, scalar }, + }, + Span::default(), + )), + // If the argument is a pointer whose address space isn't `Function`, an + // indirection through a local variable is needed to align the address + // spaces of the call argument and the overload parameter. + TypeInner::Pointer { base, space } if space != AddressSpace::Function => Some(base), + TypeInner::ValuePointer { + size, + scalar, + space, + } if space != AddressSpace::Function => { + let inner = match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + Some( + ctx.module + .types + .insert(Type { name: None, inner }, Span::default()), + ) + } + _ => None, + }; + + // Since the original expression might be a pointer and we want a value + // for the proxy writes, we might need to load the pointer. + let value = if original_pointer_space.is_some() { + ctx.add_expression(Expression::Load { pointer: original }, Span::default())? + } else { + original + }; + + ctx.typifier_grow(call_argument.0, call_argument.1)?; + + let overload_param_ty = &ctx.module.types[parameter_ty].inner; + let call_arg_ty = ctx.get_type(call_argument.0); + let needs_conversion = call_arg_ty != overload_param_ty; + + let arg_scalar_comps = scalar_components(call_arg_ty); + + // Since output parameters also allow implicit conversions from the + // parameter to the argument, we need to spill the conversion to a + // variable and create a proxy write for the original variable. + if needs_conversion { + maybe_ty = Some(parameter_ty); + } + + if let Some(ty) = maybe_ty { + // Create the spill variable + let spill_var = ctx.locals.append( + LocalVariable { + name: None, + ty, + init: None, + }, + Span::default(), + ); + let spill_expr = + ctx.add_expression(Expression::LocalVariable(spill_var), Span::default())?; + + // If the argument is also copied in we must store the value of the + // original variable to the spill variable. + if let ParameterQualifier::InOut = parameter_info.qualifier { + ctx.body.push( + Statement::Store { + pointer: spill_expr, + value, + }, + Span::default(), + ); + } + + // Add the spill variable as an argument to the function call + arguments.push(spill_expr); + + let convert = if needs_conversion { + arg_scalar_comps + } else { + None + }; + + // Register the temporary local to be written back to it's original + // place after the function call + if let Expression::Swizzle { + size, + mut vector, + pattern, + } = ctx.expressions[original] + { + if let Expression::Load { pointer } = ctx.expressions[vector] { + vector = pointer; + } + + for (i, component) in pattern.iter().take(size as usize).enumerate() { + let original = ctx.add_expression( + Expression::AccessIndex { + base: vector, + index: *component as u32, + }, + Span::default(), + )?; + + let spill_component = ctx.add_expression( + Expression::AccessIndex { + base: spill_expr, + index: i as u32, + }, + Span::default(), + )?; + + proxy_writes.push(ProxyWrite { + target: original, + value: spill_component, + convert, + }); + } + } else { + proxy_writes.push(ProxyWrite { + target: original, + value: spill_expr, + convert, + }); + } + } else { + arguments.push(original); + } + + Ok(()) + } + + pub(crate) fn add_function( + &mut self, + mut ctx: Context, + name: String, + result: Option<FunctionResult>, + meta: Span, + ) { + ensure_block_returns(&mut ctx.body); + + let void = result.is_none(); + + // Check if the passed arguments require any special variations + let mut variations = builtin_required_variations( + ctx.parameters + .iter() + .map(|&arg| &ctx.module.types[arg].inner), + ); + + // Initiate the declaration if it wasn't previously initialized and inject builtins + let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { + variations |= BuiltinVariations::STANDARD; + Default::default() + }); + inject_builtin(declaration, ctx.module, &name, variations); + + let Context { + expressions, + locals, + arguments, + parameters, + parameters_info, + body, + module, + .. + } = ctx; + + let function = Function { + name: Some(name), + arguments, + result, + local_variables: locals, + expressions, + named_expressions: crate::NamedExpressions::default(), + body, + }; + + 'outer: for decl in declaration.overloads.iter_mut() { + if parameters.len() != decl.parameters.len() { + continue; + } + + for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) { + let new_inner = &module.types[*new_parameter].inner; + let old_inner = &module.types[*old_parameter].inner; + + if new_inner != old_inner { + continue 'outer; + } + } + + if decl.defined { + return self.errors.push(Error { + kind: ErrorKind::SemanticError("Function already defined".into()), + meta, + }); + } + + decl.defined = true; + decl.parameters_info = parameters_info; + match decl.kind { + FunctionKind::Call(handle) => *module.functions.get_mut(handle) = function, + FunctionKind::Macro(_) => { + let handle = module.functions.append(function, meta); + decl.kind = FunctionKind::Call(handle) + } + } + return; + } + + let handle = module.functions.append(function, meta); + declaration.overloads.push(Overload { + parameters, + parameters_info, + kind: FunctionKind::Call(handle), + defined: true, + internal: false, + void, + }); + } + + pub(crate) fn add_prototype( + &mut self, + ctx: Context, + name: String, + result: Option<FunctionResult>, + meta: Span, + ) { + let void = result.is_none(); + + // Check if the passed arguments require any special variations + let mut variations = builtin_required_variations( + ctx.parameters + .iter() + .map(|&arg| &ctx.module.types[arg].inner), + ); + + // Initiate the declaration if it wasn't previously initialized and inject builtins + let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { + variations |= BuiltinVariations::STANDARD; + Default::default() + }); + inject_builtin(declaration, ctx.module, &name, variations); + + let Context { + arguments, + parameters, + parameters_info, + module, + .. + } = ctx; + + let function = Function { + name: Some(name), + arguments, + result, + ..Default::default() + }; + + 'outer: for decl in declaration.overloads.iter() { + if parameters.len() != decl.parameters.len() { + continue; + } + + for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) { + let new_inner = &module.types[*new_parameter].inner; + let old_inner = &module.types[*old_parameter].inner; + + if new_inner != old_inner { + continue 'outer; + } + } + + return self.errors.push(Error { + kind: ErrorKind::SemanticError("Prototype already defined".into()), + meta, + }); + } + + let handle = module.functions.append(function, meta); + declaration.overloads.push(Overload { + parameters, + parameters_info, + kind: FunctionKind::Call(handle), + defined: false, + internal: false, + void, + }); + } + + /// Create a Naga [`EntryPoint`] that calls the GLSL `main` function. + /// + /// We compile the GLSL `main` function as an ordinary Naga [`Function`]. + /// This function synthesizes a Naga [`EntryPoint`] to call that. + /// + /// Each GLSL input and output variable (including builtins) becomes a Naga + /// [`GlobalVariable`]s in the [`Private`] address space, which `main` can + /// access in the usual way. + /// + /// The `EntryPoint` we synthesize here has an argument for each GLSL input + /// variable, and returns a struct with a member for each GLSL output + /// variable. The entry point contains code to: + /// + /// - copy its arguments into the Naga globals representing the GLSL input + /// variables, + /// + /// - call the Naga `Function` representing the GLSL `main` function, and then + /// + /// - build its return value from whatever values the GLSL `main` left in + /// the Naga globals representing GLSL `output` variables. + /// + /// Upon entry, [`ctx.body`] should contain code, accumulated by prior calls + /// to [`ParsingContext::parse_external_declaration`][pxd], to initialize + /// private global variables as needed. This code gets spliced into the + /// entry point before the call to `main`. + /// + /// [`GlobalVariable`]: crate::GlobalVariable + /// [`Private`]: crate::AddressSpace::Private + /// [`ctx.body`]: Context::body + /// [pxd]: super::ParsingContext::parse_external_declaration + pub(crate) fn add_entry_point( + &mut self, + function: Handle<Function>, + mut ctx: Context, + ) -> Result<()> { + let mut arguments = Vec::new(); + + let body = Block::with_capacity( + // global init body + ctx.body.len() + + // prologue and epilogue + self.entry_args.len() * 2 + // Call, Emit for composing struct and return + + 3, + ); + + let global_init_body = std::mem::replace(&mut ctx.body, body); + + for arg in self.entry_args.iter() { + if arg.storage != StorageQualifier::Input { + continue; + } + + let pointer = ctx + .expressions + .append(Expression::GlobalVariable(arg.handle), Default::default()); + + let ty = ctx.module.global_variables[arg.handle].ty; + + ctx.arg_type_walker( + arg.name.clone(), + arg.binding.clone(), + pointer, + ty, + &mut |ctx, name, pointer, ty, binding| { + let idx = arguments.len() as u32; + + arguments.push(FunctionArgument { + name, + ty, + binding: Some(binding), + }); + + let value = ctx + .expressions + .append(Expression::FunctionArgument(idx), Default::default()); + ctx.body + .push(Statement::Store { pointer, value }, Default::default()); + }, + )? + } + + ctx.body.extend_block(global_init_body); + + ctx.body.push( + Statement::Call { + function, + arguments: Vec::new(), + result: None, + }, + Default::default(), + ); + + let mut span = 0; + let mut members = Vec::new(); + let mut components = Vec::new(); + + for arg in self.entry_args.iter() { + if arg.storage != StorageQualifier::Output { + continue; + } + + let pointer = ctx + .expressions + .append(Expression::GlobalVariable(arg.handle), Default::default()); + + let ty = ctx.module.global_variables[arg.handle].ty; + + ctx.arg_type_walker( + arg.name.clone(), + arg.binding.clone(), + pointer, + ty, + &mut |ctx, name, pointer, ty, binding| { + members.push(StructMember { + name, + ty, + binding: Some(binding), + offset: span, + }); + + span += ctx.module.types[ty].inner.size(ctx.module.to_ctx()); + + let len = ctx.expressions.len(); + let load = ctx + .expressions + .append(Expression::Load { pointer }, Default::default()); + ctx.body.push( + Statement::Emit(ctx.expressions.range_from(len)), + Default::default(), + ); + components.push(load) + }, + )? + } + + let (ty, value) = if !components.is_empty() { + let ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Struct { members, span }, + }, + Default::default(), + ); + + let len = ctx.expressions.len(); + let res = ctx + .expressions + .append(Expression::Compose { ty, components }, Default::default()); + ctx.body.push( + Statement::Emit(ctx.expressions.range_from(len)), + Default::default(), + ); + + (Some(ty), Some(res)) + } else { + (None, None) + }; + + ctx.body + .push(Statement::Return { value }, Default::default()); + + let Context { + body, expressions, .. + } = ctx; + + ctx.module.entry_points.push(EntryPoint { + name: "main".to_string(), + stage: self.meta.stage, + early_depth_test: Some(crate::EarlyDepthTest { conservative: None }) + .filter(|_| self.meta.early_fragment_tests), + workgroup_size: self.meta.workgroup_size, + function: Function { + arguments, + expressions, + body, + result: ty.map(|ty| FunctionResult { ty, binding: None }), + ..Default::default() + }, + }); + + Ok(()) + } +} + +impl Context<'_> { + /// Helper function for building the input/output interface of the entry point + /// + /// Calls `f` with the data of the entry point argument, flattening composite types + /// recursively + /// + /// The passed arguments to the callback are: + /// - The ctx + /// - The name + /// - The pointer expression to the global storage + /// - The handle to the type of the entry point argument + /// - The binding of the entry point argument + fn arg_type_walker( + &mut self, + name: Option<String>, + binding: crate::Binding, + pointer: Handle<Expression>, + ty: Handle<Type>, + f: &mut impl FnMut( + &mut Context, + Option<String>, + Handle<Expression>, + Handle<Type>, + crate::Binding, + ), + ) -> Result<()> { + match self.module.types[ty].inner { + // TODO: Better error reporting + // right now we just don't walk the array if the size isn't known at + // compile time and let validation catch it + TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => { + let mut location = match binding { + crate::Binding::Location { location, .. } => location, + crate::Binding::BuiltIn(_) => return Ok(()), + }; + + let interpolation = + self.module.types[base] + .inner + .scalar_kind() + .map(|kind| match kind { + ScalarKind::Float => crate::Interpolation::Perspective, + _ => crate::Interpolation::Flat, + }); + + for index in 0..size.get() { + let member_pointer = self.add_expression( + Expression::AccessIndex { + base: pointer, + index, + }, + crate::Span::default(), + )?; + + let binding = crate::Binding::Location { + location, + interpolation, + sampling: None, + second_blend_source: false, + }; + location += 1; + + self.arg_type_walker(name.clone(), binding, member_pointer, base, f)? + } + } + TypeInner::Struct { ref members, .. } => { + let mut location = match binding { + crate::Binding::Location { location, .. } => location, + crate::Binding::BuiltIn(_) => return Ok(()), + }; + + for (i, member) in members.clone().into_iter().enumerate() { + let member_pointer = self.add_expression( + Expression::AccessIndex { + base: pointer, + index: i as u32, + }, + crate::Span::default(), + )?; + + let binding = match member.binding { + Some(binding) => binding, + None => { + let interpolation = self.module.types[member.ty] + .inner + .scalar_kind() + .map(|kind| match kind { + ScalarKind::Float => crate::Interpolation::Perspective, + _ => crate::Interpolation::Flat, + }); + let binding = crate::Binding::Location { + location, + interpolation, + sampling: None, + second_blend_source: false, + }; + location += 1; + binding + } + }; + + self.arg_type_walker(member.name, binding, member_pointer, member.ty, f)? + } + } + _ => f(self, name, pointer, ty, binding), + } + + Ok(()) + } +} + +/// Helper enum containing the type of conversion need for a call +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +enum Conversion { + /// No conversion needed + Exact, + /// Float to double conversion needed + FloatToDouble, + /// Int or uint to float conversion needed + IntToFloat, + /// Int or uint to double conversion needed + IntToDouble, + /// Other type of conversion needed + Other, + /// No conversion was yet registered + None, +} + +/// Helper function, returns the type of conversion from `source` to `target`, if a +/// conversion is not possible returns None. +fn conversion(target: &TypeInner, source: &TypeInner) -> Option<Conversion> { + use ScalarKind::*; + + // Gather the `ScalarKind` and scalar width from both the target and the source + let (target_scalar, source_scalar) = match (target, source) { + // Conversions between scalars are allowed + (&TypeInner::Scalar(tgt_scalar), &TypeInner::Scalar(src_scalar)) => { + (tgt_scalar, src_scalar) + } + // Conversions between vectors of the same size are allowed + ( + &TypeInner::Vector { + size: tgt_size, + scalar: tgt_scalar, + }, + &TypeInner::Vector { + size: src_size, + scalar: src_scalar, + }, + ) if tgt_size == src_size => (tgt_scalar, src_scalar), + // Conversions between matrices of the same size are allowed + ( + &TypeInner::Matrix { + rows: tgt_rows, + columns: tgt_cols, + scalar: tgt_scalar, + }, + &TypeInner::Matrix { + rows: src_rows, + columns: src_cols, + scalar: src_scalar, + }, + ) if tgt_cols == src_cols && tgt_rows == src_rows => (tgt_scalar, src_scalar), + _ => return None, + }; + + // Check if source can be converted into target, if this is the case then the type + // power of target must be higher than that of source + let target_power = type_power(target_scalar); + let source_power = type_power(source_scalar); + if target_power < source_power { + return None; + } + + Some(match (target_scalar, source_scalar) { + // A conversion from a float to a double is special + (Scalar::F64, Scalar::F32) => Conversion::FloatToDouble, + // A conversion from an integer to a float is special + ( + Scalar::F32, + Scalar { + kind: Sint | Uint, + width: _, + }, + ) => Conversion::IntToFloat, + // A conversion from an integer to a double is special + ( + Scalar::F64, + Scalar { + kind: Sint | Uint, + width: _, + }, + ) => Conversion::IntToDouble, + _ => Conversion::Other, + }) +} + +/// Helper method returning all the non standard builtin variations needed +/// to process the function call with the passed arguments +fn builtin_required_variations<'a>(args: impl Iterator<Item = &'a TypeInner>) -> BuiltinVariations { + let mut variations = BuiltinVariations::empty(); + + for ty in args { + match *ty { + TypeInner::ValuePointer { scalar, .. } + | TypeInner::Scalar(scalar) + | TypeInner::Vector { scalar, .. } + | TypeInner::Matrix { scalar, .. } => { + if scalar == Scalar::F64 { + variations |= BuiltinVariations::DOUBLE + } + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + if dim == crate::ImageDimension::Cube && arrayed { + variations |= BuiltinVariations::CUBE_TEXTURES_ARRAY + } + + if dim == crate::ImageDimension::D2 && arrayed && class.is_multisampled() { + variations |= BuiltinVariations::D2_MULTI_TEXTURES_ARRAY + } + } + _ => {} + } + } + + variations +} diff --git a/third_party/rust/naga/src/front/glsl/lex.rs b/third_party/rust/naga/src/front/glsl/lex.rs new file mode 100644 index 0000000000..1b59a9bf3e --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/lex.rs @@ -0,0 +1,301 @@ +use super::{ + ast::Precision, + token::{Directive, DirectiveKind, Token, TokenValue}, + types::parse_type, +}; +use crate::{FastHashMap, Span, StorageAccess}; +use pp_rs::{ + pp::Preprocessor, + token::{PreprocessorError, Punct, TokenValue as PPTokenValue}, +}; + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub struct LexerResult { + pub kind: LexerResultKind, + pub meta: Span, +} + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub enum LexerResultKind { + Token(Token), + Directive(Directive), + Error(PreprocessorError), +} + +pub struct Lexer<'a> { + pp: Preprocessor<'a>, +} + +impl<'a> Lexer<'a> { + pub fn new(input: &'a str, defines: &'a FastHashMap<String, String>) -> Self { + let mut pp = Preprocessor::new(input); + for (define, value) in defines { + pp.add_define(define, value).unwrap(); //TODO: handle error + } + Lexer { pp } + } +} + +impl<'a> Iterator for Lexer<'a> { + type Item = LexerResult; + fn next(&mut self) -> Option<Self::Item> { + let pp_token = match self.pp.next()? { + Ok(t) => t, + Err((err, loc)) => { + return Some(LexerResult { + kind: LexerResultKind::Error(err), + meta: loc.into(), + }); + } + }; + + let meta = pp_token.location.into(); + let value = match pp_token.value { + PPTokenValue::Extension(extension) => { + return Some(LexerResult { + kind: LexerResultKind::Directive(Directive { + kind: DirectiveKind::Extension, + tokens: extension.tokens, + }), + meta, + }) + } + PPTokenValue::Float(float) => TokenValue::FloatConstant(float), + PPTokenValue::Ident(ident) => { + match ident.as_str() { + // Qualifiers + "layout" => TokenValue::Layout, + "in" => TokenValue::In, + "out" => TokenValue::Out, + "uniform" => TokenValue::Uniform, + "buffer" => TokenValue::Buffer, + "shared" => TokenValue::Shared, + "invariant" => TokenValue::Invariant, + "flat" => TokenValue::Interpolation(crate::Interpolation::Flat), + "noperspective" => TokenValue::Interpolation(crate::Interpolation::Linear), + "smooth" => TokenValue::Interpolation(crate::Interpolation::Perspective), + "centroid" => TokenValue::Sampling(crate::Sampling::Centroid), + "sample" => TokenValue::Sampling(crate::Sampling::Sample), + "const" => TokenValue::Const, + "inout" => TokenValue::InOut, + "precision" => TokenValue::Precision, + "highp" => TokenValue::PrecisionQualifier(Precision::High), + "mediump" => TokenValue::PrecisionQualifier(Precision::Medium), + "lowp" => TokenValue::PrecisionQualifier(Precision::Low), + "restrict" => TokenValue::Restrict, + "readonly" => TokenValue::MemoryQualifier(StorageAccess::LOAD), + "writeonly" => TokenValue::MemoryQualifier(StorageAccess::STORE), + // values + "true" => TokenValue::BoolConstant(true), + "false" => TokenValue::BoolConstant(false), + // jump statements + "continue" => TokenValue::Continue, + "break" => TokenValue::Break, + "return" => TokenValue::Return, + "discard" => TokenValue::Discard, + // selection statements + "if" => TokenValue::If, + "else" => TokenValue::Else, + "switch" => TokenValue::Switch, + "case" => TokenValue::Case, + "default" => TokenValue::Default, + // iteration statements + "while" => TokenValue::While, + "do" => TokenValue::Do, + "for" => TokenValue::For, + // types + "void" => TokenValue::Void, + "struct" => TokenValue::Struct, + word => match parse_type(word) { + Some(t) => TokenValue::TypeName(t), + None => TokenValue::Identifier(String::from(word)), + }, + } + } + PPTokenValue::Integer(integer) => TokenValue::IntConstant(integer), + PPTokenValue::Punct(punct) => match punct { + // Compound assignments + Punct::AddAssign => TokenValue::AddAssign, + Punct::SubAssign => TokenValue::SubAssign, + Punct::MulAssign => TokenValue::MulAssign, + Punct::DivAssign => TokenValue::DivAssign, + Punct::ModAssign => TokenValue::ModAssign, + Punct::LeftShiftAssign => TokenValue::LeftShiftAssign, + Punct::RightShiftAssign => TokenValue::RightShiftAssign, + Punct::AndAssign => TokenValue::AndAssign, + Punct::XorAssign => TokenValue::XorAssign, + Punct::OrAssign => TokenValue::OrAssign, + + // Two character punctuation + Punct::Increment => TokenValue::Increment, + Punct::Decrement => TokenValue::Decrement, + Punct::LogicalAnd => TokenValue::LogicalAnd, + Punct::LogicalOr => TokenValue::LogicalOr, + Punct::LogicalXor => TokenValue::LogicalXor, + Punct::LessEqual => TokenValue::LessEqual, + Punct::GreaterEqual => TokenValue::GreaterEqual, + Punct::EqualEqual => TokenValue::Equal, + Punct::NotEqual => TokenValue::NotEqual, + Punct::LeftShift => TokenValue::LeftShift, + Punct::RightShift => TokenValue::RightShift, + + // Parenthesis or similar + Punct::LeftBrace => TokenValue::LeftBrace, + Punct::RightBrace => TokenValue::RightBrace, + Punct::LeftParen => TokenValue::LeftParen, + Punct::RightParen => TokenValue::RightParen, + Punct::LeftBracket => TokenValue::LeftBracket, + Punct::RightBracket => TokenValue::RightBracket, + + // Other one character punctuation + Punct::LeftAngle => TokenValue::LeftAngle, + Punct::RightAngle => TokenValue::RightAngle, + Punct::Semicolon => TokenValue::Semicolon, + Punct::Comma => TokenValue::Comma, + Punct::Colon => TokenValue::Colon, + Punct::Dot => TokenValue::Dot, + Punct::Equal => TokenValue::Assign, + Punct::Bang => TokenValue::Bang, + Punct::Minus => TokenValue::Dash, + Punct::Tilde => TokenValue::Tilde, + Punct::Plus => TokenValue::Plus, + Punct::Star => TokenValue::Star, + Punct::Slash => TokenValue::Slash, + Punct::Percent => TokenValue::Percent, + Punct::Pipe => TokenValue::VerticalBar, + Punct::Caret => TokenValue::Caret, + Punct::Ampersand => TokenValue::Ampersand, + Punct::Question => TokenValue::Question, + }, + PPTokenValue::Pragma(pragma) => { + return Some(LexerResult { + kind: LexerResultKind::Directive(Directive { + kind: DirectiveKind::Pragma, + tokens: pragma.tokens, + }), + meta, + }) + } + PPTokenValue::Version(version) => { + return Some(LexerResult { + kind: LexerResultKind::Directive(Directive { + kind: DirectiveKind::Version { + is_first_directive: version.is_first_directive, + }, + tokens: version.tokens, + }), + meta, + }) + } + }; + + Some(LexerResult { + kind: LexerResultKind::Token(Token { value, meta }), + meta, + }) + } +} + +#[cfg(test)] +mod tests { + use pp_rs::token::{Integer, Location, Token as PPToken, TokenValue as PPTokenValue}; + + use super::{ + super::token::{Directive, DirectiveKind, Token, TokenValue}, + Lexer, LexerResult, LexerResultKind, + }; + use crate::Span; + + #[test] + fn lex_tokens() { + let defines = crate::FastHashMap::default(); + + // line comments + let mut lex = Lexer::new("#version 450\nvoid main () {}", &defines); + let mut location = Location::default(); + location.start = 9; + location.end = 12; + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Directive(Directive { + kind: DirectiveKind::Version { + is_first_directive: true + }, + tokens: vec![PPToken { + value: PPTokenValue::Integer(Integer { + signed: true, + value: 450, + width: 32 + }), + location + }] + }), + meta: Span::new(1, 8) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::Void, + meta: Span::new(13, 17) + }), + meta: Span::new(13, 17) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::Identifier("main".into()), + meta: Span::new(18, 22) + }), + meta: Span::new(18, 22) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::LeftParen, + meta: Span::new(23, 24) + }), + meta: Span::new(23, 24) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::RightParen, + meta: Span::new(24, 25) + }), + meta: Span::new(24, 25) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::LeftBrace, + meta: Span::new(26, 27) + }), + meta: Span::new(26, 27) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::RightBrace, + meta: Span::new(27, 28) + }), + meta: Span::new(27, 28) + } + ); + assert_eq!(lex.next(), None); + } +} diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs new file mode 100644 index 0000000000..75f3929db4 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/mod.rs @@ -0,0 +1,232 @@ +/*! +Frontend for [GLSL][glsl] (OpenGL Shading Language). + +To begin, take a look at the documentation for the [`Frontend`]. + +# Supported versions +## Vulkan +- 440 (partial) +- 450 +- 460 + +[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php +*/ + +pub use ast::{Precision, Profile}; +pub use error::{Error, ErrorKind, ExpectedToken, ParseError}; +pub use token::TokenValue; + +use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type}; +use ast::{EntryArg, FunctionDeclaration, GlobalLookup}; +use parser::ParsingContext; + +mod ast; +mod builtins; +mod context; +mod error; +mod functions; +mod lex; +mod offset; +mod parser; +#[cfg(test)] +mod parser_tests; +mod token; +mod types; +mod variables; + +type Result<T> = std::result::Result<T, Error>; + +/// Per-shader options passed to [`parse`](Frontend::parse). +/// +/// The [`From`] trait is implemented for [`ShaderStage`] to provide a quick way +/// to create an `Options` instance. +/// +/// ```rust +/// # use naga::ShaderStage; +/// # use naga::front::glsl::Options; +/// Options::from(ShaderStage::Vertex); +/// ``` +#[derive(Debug)] +pub struct Options { + /// The shader stage in the pipeline. + pub stage: ShaderStage, + /// Preprocessor definitions to be used, akin to having + /// ```glsl + /// #define key value + /// ``` + /// for each key value pair in the map. + pub defines: FastHashMap<String, String>, +} + +impl From<ShaderStage> for Options { + fn from(stage: ShaderStage) -> Self { + Options { + stage, + defines: FastHashMap::default(), + } + } +} + +/// Additional information about the GLSL shader. +/// +/// Stores additional information about the GLSL shader which might not be +/// stored in the shader [`Module`]. +#[derive(Debug)] +pub struct ShaderMetadata { + /// The GLSL version specified in the shader through the use of the + /// `#version` preprocessor directive. + pub version: u16, + /// The GLSL profile specified in the shader through the use of the + /// `#version` preprocessor directive. + pub profile: Profile, + /// The shader stage in the pipeline, passed to the [`parse`](Frontend::parse) + /// method via the [`Options`] struct. + pub stage: ShaderStage, + + /// The workgroup size for compute shaders, defaults to `[1; 3]` for + /// compute shaders and `[0; 3]` for non compute shaders. + pub workgroup_size: [u32; 3], + /// Whether or not early fragment tests where requested by the shader. + /// Defaults to `false`. + pub early_fragment_tests: bool, + + /// The shader can request extensions via the + /// `#extension` preprocessor directive, in the directive a behavior + /// parameter is used to control whether the extension should be disabled, + /// warn on usage, enabled if possible or required. + /// + /// This field only stores extensions which were required or requested to + /// be enabled if possible and they are supported. + pub extensions: FastHashSet<String>, +} + +impl ShaderMetadata { + fn reset(&mut self, stage: ShaderStage) { + self.version = 0; + self.profile = Profile::Core; + self.stage = stage; + self.workgroup_size = [u32::from(stage == ShaderStage::Compute); 3]; + self.early_fragment_tests = false; + self.extensions.clear(); + } +} + +impl Default for ShaderMetadata { + fn default() -> Self { + ShaderMetadata { + version: 0, + profile: Profile::Core, + stage: ShaderStage::Vertex, + workgroup_size: [0; 3], + early_fragment_tests: false, + extensions: FastHashSet::default(), + } + } +} + +/// The `Frontend` is the central structure of the GLSL frontend. +/// +/// To instantiate a new `Frontend` the [`Default`] trait is used, so a +/// call to the associated function [`Frontend::default`](Frontend::default) will +/// return a new `Frontend` instance. +/// +/// To parse a shader simply call the [`parse`](Frontend::parse) method with a +/// [`Options`] struct and a [`&str`](str) holding the glsl code. +/// +/// The `Frontend` also provides the [`metadata`](Frontend::metadata) to get some +/// further information about the previously parsed shader, like version and +/// extensions used (see the documentation for +/// [`ShaderMetadata`] to see all the returned information) +/// +/// # Example usage +/// ```rust +/// use naga::ShaderStage; +/// use naga::front::glsl::{Frontend, Options}; +/// +/// let glsl = r#" +/// #version 450 core +/// +/// void main() {} +/// "#; +/// +/// let mut frontend = Frontend::default(); +/// let options = Options::from(ShaderStage::Vertex); +/// frontend.parse(&options, glsl); +/// ``` +/// +/// # Reusability +/// +/// If there's a need to parse more than one shader reusing the same `Frontend` +/// instance may be beneficial since internal allocations will be reused. +/// +/// Calling the [`parse`](Frontend::parse) method multiple times will reset the +/// `Frontend` so no extra care is needed when reusing. +#[derive(Debug, Default)] +pub struct Frontend { + meta: ShaderMetadata, + + lookup_function: FastHashMap<String, FunctionDeclaration>, + lookup_type: FastHashMap<String, Handle<Type>>, + + global_variables: Vec<(String, GlobalLookup)>, + + entry_args: Vec<EntryArg>, + + layouter: Layouter, + + errors: Vec<Error>, +} + +impl Frontend { + fn reset(&mut self, stage: ShaderStage) { + self.meta.reset(stage); + + self.lookup_function.clear(); + self.lookup_type.clear(); + self.global_variables.clear(); + self.entry_args.clear(); + self.layouter.clear(); + } + + /// Parses a shader either outputting a shader [`Module`] or a list of + /// [`Error`]s. + /// + /// Multiple calls using the same `Frontend` and different shaders are supported. + pub fn parse( + &mut self, + options: &Options, + source: &str, + ) -> std::result::Result<Module, ParseError> { + self.reset(options.stage); + + let lexer = lex::Lexer::new(source, &options.defines); + let mut ctx = ParsingContext::new(lexer); + + match ctx.parse(self) { + Ok(module) => { + if self.errors.is_empty() { + Ok(module) + } else { + Err(std::mem::take(&mut self.errors).into()) + } + } + Err(e) => { + self.errors.push(e); + Err(std::mem::take(&mut self.errors).into()) + } + } + } + + /// Returns additional information about the parsed shader which might not + /// be stored in the [`Module`], see the documentation for + /// [`ShaderMetadata`] for more information about the returned data. + /// + /// # Notes + /// + /// Following an unsuccessful parsing the state of the returned information + /// is undefined, it might contain only partial information about the + /// current shader, the previous shader or both. + pub const fn metadata(&self) -> &ShaderMetadata { + &self.meta + } +} diff --git a/third_party/rust/naga/src/front/glsl/offset.rs b/third_party/rust/naga/src/front/glsl/offset.rs new file mode 100644 index 0000000000..c88c46598d --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/offset.rs @@ -0,0 +1,173 @@ +/*! +Module responsible for calculating the offset and span for types. + +There exists two types of layouts std140 and std430 (there's technically +two more layouts, shared and packed. Shared is not supported by spirv. Packed is +implementation dependent and for now it's just implemented as an alias to +std140). + +The OpenGl spec (the layout rules are defined by the OpenGl spec in section +7.6.2.2 as opposed to the GLSL spec) uses the term basic machine units which are +equivalent to bytes. +*/ + +use super::{ + ast::StructLayout, + error::{Error, ErrorKind}, + Span, +}; +use crate::{proc::Alignment, Handle, Scalar, Type, TypeInner, UniqueArena}; + +/// Struct with information needed for defining a struct member. +/// +/// Returned by [`calculate_offset`]. +#[derive(Debug)] +pub struct TypeAlignSpan { + /// The handle to the type, this might be the same handle passed to + /// [`calculate_offset`] or a new such a new array type with a different + /// stride set. + pub ty: Handle<Type>, + /// The alignment required by the type. + pub align: Alignment, + /// The size of the type. + pub span: u32, +} + +/// Returns the type, alignment and span of a struct member according to a [`StructLayout`]. +/// +/// The functions returns a [`TypeAlignSpan`] which has a `ty` member this +/// should be used as the struct member type because for example arrays may have +/// to change the stride and as such need to have a different type. +pub fn calculate_offset( + mut ty: Handle<Type>, + meta: Span, + layout: StructLayout, + types: &mut UniqueArena<Type>, + errors: &mut Vec<Error>, +) -> TypeAlignSpan { + // When using the std430 storage layout, shader storage blocks will be laid out in buffer storage + // identically to uniform and shader storage blocks using the std140 layout, except + // that the base alignment and stride of arrays of scalars and vectors in rule 4 and of + // structures in rule 9 are not rounded up a multiple of the base alignment of a vec4. + + let (align, span) = match types[ty].inner { + // 1. If the member is a scalar consuming N basic machine units, + // the base alignment is N. + TypeInner::Scalar(Scalar { width, .. }) => (Alignment::from_width(width), width as u32), + // 2. If the member is a two- or four-component vector with components + // consuming N basic machine units, the base alignment is 2N or 4N, respectively. + // 3. If the member is a three-component vector with components consuming N + // basic machine units, the base alignment is 4N. + TypeInner::Vector { + size, + scalar: Scalar { width, .. }, + } => ( + Alignment::from(size) * Alignment::from_width(width), + size as u32 * width as u32, + ), + // 4. If the member is an array of scalars or vectors, the base alignment and array + // stride are set to match the base alignment of a single array element, according + // to rules (1), (2), and (3), and rounded up to the base alignment of a vec4. + // TODO: Matrices array + TypeInner::Array { base, size, .. } => { + let info = calculate_offset(base, meta, layout, types, errors); + + let name = types[ty].name.clone(); + + // See comment at the beginning of the function + let (align, stride) = if StructLayout::Std430 == layout { + (info.align, info.align.round_up(info.span)) + } else { + let align = info.align.max(Alignment::MIN_UNIFORM); + (align, align.round_up(info.span)) + }; + + let span = match size { + crate::ArraySize::Constant(size) => size.get() * stride, + crate::ArraySize::Dynamic => stride, + }; + + let ty_span = types.get_span(ty); + ty = types.insert( + Type { + name, + inner: TypeInner::Array { + base: info.ty, + size, + stride, + }, + }, + ty_span, + ); + + (align, span) + } + // 5. If the member is a column-major matrix with C columns and R rows, the + // matrix is stored identically to an array of C column vectors with R + // components each, according to rule (4) + // TODO: Row major matrices + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + let mut align = Alignment::from(rows) * Alignment::from_width(scalar.width); + + // See comment at the beginning of the function + if StructLayout::Std430 != layout { + align = align.max(Alignment::MIN_UNIFORM); + } + + // See comment on the error kind + if StructLayout::Std140 == layout && rows == crate::VectorSize::Bi { + errors.push(Error { + kind: ErrorKind::UnsupportedMatrixTypeInStd140, + meta, + }); + } + + (align, align * columns as u32) + } + TypeInner::Struct { ref members, .. } => { + let mut span = 0; + let mut align = Alignment::ONE; + let mut members = members.clone(); + let name = types[ty].name.clone(); + + for member in members.iter_mut() { + let info = calculate_offset(member.ty, meta, layout, types, errors); + + let member_alignment = info.align; + span = member_alignment.round_up(span); + align = member_alignment.max(align); + + member.ty = info.ty; + member.offset = span; + + span += info.span; + } + + span = align.round_up(span); + + let ty_span = types.get_span(ty); + ty = types.insert( + Type { + name, + inner: TypeInner::Struct { members, span }, + }, + ty_span, + ); + + (align, span) + } + _ => { + errors.push(Error { + kind: ErrorKind::SemanticError("Invalid struct member type".into()), + meta, + }); + (Alignment::ONE, 0) + } + }; + + TypeAlignSpan { ty, align, span } +} diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs new file mode 100644 index 0000000000..851d2e1d79 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser.rs @@ -0,0 +1,431 @@ +use super::{ + ast::{FunctionKind, Profile, TypeQualifiers}, + context::{Context, ExprPos}, + error::ExpectedToken, + error::{Error, ErrorKind}, + lex::{Lexer, LexerResultKind}, + token::{Directive, DirectiveKind}, + token::{Token, TokenValue}, + variables::{GlobalOrConstant, VarDeclaration}, + Frontend, Result, +}; +use crate::{arena::Handle, proc::U32EvalError, Expression, Module, Span, Type}; +use pp_rs::token::{PreprocessorError, Token as PPToken, TokenValue as PPTokenValue}; +use std::iter::Peekable; + +mod declarations; +mod expressions; +mod functions; +mod types; + +pub struct ParsingContext<'source> { + lexer: Peekable<Lexer<'source>>, + /// Used to store tokens already consumed by the parser but that need to be backtracked + backtracked_token: Option<Token>, + last_meta: Span, +} + +impl<'source> ParsingContext<'source> { + pub fn new(lexer: Lexer<'source>) -> Self { + ParsingContext { + lexer: lexer.peekable(), + backtracked_token: None, + last_meta: Span::default(), + } + } + + /// Helper method for backtracking from a consumed token + /// + /// This method should always be used instead of assigning to `backtracked_token` since + /// it validates that backtracking hasn't occurred more than one time in a row + /// + /// # Panics + /// - If the parser already backtracked without bumping in between + pub fn backtrack(&mut self, token: Token) -> Result<()> { + // This should never happen + if let Some(ref prev_token) = self.backtracked_token { + return Err(Error { + kind: ErrorKind::InternalError("The parser tried to backtrack twice in a row"), + meta: prev_token.meta, + }); + } + + self.backtracked_token = Some(token); + + Ok(()) + } + + pub fn expect_ident(&mut self, frontend: &mut Frontend) -> Result<(String, Span)> { + let token = self.bump(frontend)?; + + match token.value { + TokenValue::Identifier(name) => Ok((name, token.meta)), + _ => Err(Error { + kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]), + meta: token.meta, + }), + } + } + + pub fn expect(&mut self, frontend: &mut Frontend, value: TokenValue) -> Result<Token> { + let token = self.bump(frontend)?; + + if token.value != value { + Err(Error { + kind: ErrorKind::InvalidToken(token.value, vec![value.into()]), + meta: token.meta, + }) + } else { + Ok(token) + } + } + + pub fn next(&mut self, frontend: &mut Frontend) -> Option<Token> { + loop { + if let Some(token) = self.backtracked_token.take() { + self.last_meta = token.meta; + break Some(token); + } + + let res = self.lexer.next()?; + + match res.kind { + LexerResultKind::Token(token) => { + self.last_meta = token.meta; + break Some(token); + } + LexerResultKind::Directive(directive) => { + frontend.handle_directive(directive, res.meta) + } + LexerResultKind::Error(error) => frontend.errors.push(Error { + kind: ErrorKind::PreprocessorError(error), + meta: res.meta, + }), + } + } + } + + pub fn bump(&mut self, frontend: &mut Frontend) -> Result<Token> { + self.next(frontend).ok_or(Error { + kind: ErrorKind::EndOfFile, + meta: self.last_meta, + }) + } + + /// Returns None on the end of the file rather than an error like other methods + pub fn bump_if(&mut self, frontend: &mut Frontend, value: TokenValue) -> Option<Token> { + if self.peek(frontend).filter(|t| t.value == value).is_some() { + self.bump(frontend).ok() + } else { + None + } + } + + pub fn peek(&mut self, frontend: &mut Frontend) -> Option<&Token> { + loop { + if let Some(ref token) = self.backtracked_token { + break Some(token); + } + + match self.lexer.peek()?.kind { + LexerResultKind::Token(_) => { + let res = self.lexer.peek()?; + + match res.kind { + LexerResultKind::Token(ref token) => break Some(token), + _ => unreachable!(), + } + } + LexerResultKind::Error(_) | LexerResultKind::Directive(_) => { + let res = self.lexer.next()?; + + match res.kind { + LexerResultKind::Directive(directive) => { + frontend.handle_directive(directive, res.meta) + } + LexerResultKind::Error(error) => frontend.errors.push(Error { + kind: ErrorKind::PreprocessorError(error), + meta: res.meta, + }), + LexerResultKind::Token(_) => unreachable!(), + } + } + } + } + } + + pub fn expect_peek(&mut self, frontend: &mut Frontend) -> Result<&Token> { + let meta = self.last_meta; + self.peek(frontend).ok_or(Error { + kind: ErrorKind::EndOfFile, + meta, + }) + } + + pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> { + let mut module = Module::default(); + + // Body and expression arena for global initialization + let mut ctx = Context::new(frontend, &mut module, false)?; + + while self.peek(frontend).is_some() { + self.parse_external_declaration(frontend, &mut ctx)?; + } + + // Add an `EntryPoint` to `parser.module` for `main`, if a + // suitable overload exists. Error out if we can't find one. + if let Some(declaration) = frontend.lookup_function.get("main") { + for decl in declaration.overloads.iter() { + if let FunctionKind::Call(handle) = decl.kind { + if decl.defined && decl.parameters.is_empty() { + frontend.add_entry_point(handle, ctx)?; + return Ok(module); + } + } + } + } + + Err(Error { + kind: ErrorKind::SemanticError("Missing entry point".into()), + meta: Span::default(), + }) + } + + fn parse_uint_constant( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<(u32, Span)> { + let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?; + + let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr); + + let int = match res { + Ok(value) => Ok(value), + Err(U32EvalError::Negative) => Err(Error { + kind: ErrorKind::SemanticError("int constant overflows".into()), + meta, + }), + Err(U32EvalError::NonConst) => Err(Error { + kind: ErrorKind::SemanticError("Expected a uint constant".into()), + meta, + }), + }?; + + Ok((int, meta)) + } + + fn parse_constant_expression( + &mut self, + frontend: &mut Frontend, + module: &mut Module, + ) -> Result<(Handle<Expression>, Span)> { + let mut ctx = Context::new(frontend, module, true)?; + + let mut stmt_ctx = ctx.stmt_ctx(); + let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?; + let (root, meta) = ctx.lower_expect(stmt_ctx, frontend, expr, ExprPos::Rhs)?; + + Ok((root, meta)) + } +} + +impl Frontend { + fn handle_directive(&mut self, directive: Directive, meta: Span) { + let mut tokens = directive.tokens.into_iter(); + + match directive.kind { + DirectiveKind::Version { is_first_directive } => { + if !is_first_directive { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + "#version must occur first in shader".into(), + ), + meta, + }) + } + + match tokens.next() { + Some(PPToken { + value: PPTokenValue::Integer(int), + location, + }) => match int.value { + 440 | 450 | 460 => self.meta.version = int.value as u16, + _ => self.errors.push(Error { + kind: ErrorKind::InvalidVersion(int.value), + meta: location.into(), + }), + }, + Some(PPToken { value, location }) => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }), + None => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine), + meta, + }), + }; + + match tokens.next() { + Some(PPToken { + value: PPTokenValue::Ident(name), + location, + }) => match name.as_str() { + "core" => self.meta.profile = Profile::Core, + _ => self.errors.push(Error { + kind: ErrorKind::InvalidProfile(name), + meta: location.into(), + }), + }, + Some(PPToken { value, location }) => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }), + None => {} + }; + + if let Some(PPToken { value, location }) = tokens.next() { + self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }) + } + } + DirectiveKind::Extension => { + // TODO: Proper extension handling + // - Checking for extension support in the compiler + // - Handle behaviors such as warn + // - Handle the all extension + let name = match tokens.next() { + Some(PPToken { + value: PPTokenValue::Ident(name), + .. + }) => Some(name), + Some(PPToken { value, location }) => { + self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }); + + None + } + None => { + self.errors.push(Error { + kind: ErrorKind::PreprocessorError( + PreprocessorError::UnexpectedNewLine, + ), + meta, + }); + + None + } + }; + + match tokens.next() { + Some(PPToken { + value: PPTokenValue::Punct(pp_rs::token::Punct::Colon), + .. + }) => {} + Some(PPToken { value, location }) => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }), + None => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine), + meta, + }), + }; + + match tokens.next() { + Some(PPToken { + value: PPTokenValue::Ident(behavior), + location, + }) => match behavior.as_str() { + "require" | "enable" | "warn" | "disable" => { + if let Some(name) = name { + self.meta.extensions.insert(name); + } + } + _ => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + PPTokenValue::Ident(behavior), + )), + meta: location.into(), + }), + }, + Some(PPToken { value, location }) => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }), + None => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine), + meta, + }), + } + + if let Some(PPToken { value, location }) = tokens.next() { + self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }) + } + } + DirectiveKind::Pragma => { + // TODO: handle some common pragmas? + } + } + } +} + +pub struct DeclarationContext<'ctx, 'qualifiers, 'a> { + qualifiers: TypeQualifiers<'qualifiers>, + /// Indicates a global declaration + external: bool, + is_inside_loop: bool, + ctx: &'ctx mut Context<'a>, +} + +impl<'ctx, 'qualifiers, 'a> DeclarationContext<'ctx, 'qualifiers, 'a> { + fn add_var( + &mut self, + frontend: &mut Frontend, + ty: Handle<Type>, + name: String, + init: Option<Handle<Expression>>, + meta: Span, + ) -> Result<Handle<Expression>> { + let decl = VarDeclaration { + qualifiers: &mut self.qualifiers, + ty, + name: Some(name), + init, + meta, + }; + + match self.external { + true => { + let global = frontend.add_global_var(self.ctx, decl)?; + let expr = match global { + GlobalOrConstant::Global(handle) => Expression::GlobalVariable(handle), + GlobalOrConstant::Constant(handle) => Expression::Constant(handle), + }; + Ok(self.ctx.add_expression(expr, meta)?) + } + false => frontend.add_local_var(self.ctx, decl), + } + } +} diff --git a/third_party/rust/naga/src/front/glsl/parser/declarations.rs b/third_party/rust/naga/src/front/glsl/parser/declarations.rs new file mode 100644 index 0000000000..f5e38fb016 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser/declarations.rs @@ -0,0 +1,677 @@ +use crate::{ + front::glsl::{ + ast::{ + GlobalLookup, GlobalLookupKind, Precision, QualifierKey, QualifierValue, + StorageQualifier, StructLayout, TypeQualifiers, + }, + context::{Context, ExprPos}, + error::ExpectedToken, + offset, + token::{Token, TokenValue}, + types::scalar_components, + variables::{GlobalOrConstant, VarDeclaration}, + Error, ErrorKind, Frontend, Span, + }, + proc::Alignment, + AddressSpace, Expression, FunctionResult, Handle, Scalar, ScalarKind, Statement, StructMember, + Type, TypeInner, +}; + +use super::{DeclarationContext, ParsingContext, Result}; + +/// Helper method used to retrieve the child type of `ty` at +/// index `i`. +/// +/// # Note +/// +/// Does not check if the index is valid and returns the same type +/// when indexing out-of-bounds a struct or indexing a non indexable +/// type. +fn element_or_member_type( + ty: Handle<Type>, + i: usize, + types: &mut crate::UniqueArena<Type>, +) -> Handle<Type> { + match types[ty].inner { + // The child type of a vector is a scalar of the same kind and width + TypeInner::Vector { scalar, .. } => types.insert( + Type { + name: None, + inner: TypeInner::Scalar(scalar), + }, + Default::default(), + ), + // The child type of a matrix is a vector of floats with the same + // width and the size of the matrix rows. + TypeInner::Matrix { rows, scalar, .. } => types.insert( + Type { + name: None, + inner: TypeInner::Vector { size: rows, scalar }, + }, + Default::default(), + ), + // The child type of an array is the base type of the array + TypeInner::Array { base, .. } => base, + // The child type of a struct at index `i` is the type of it's + // member at that same index. + // + // In case the index is out of bounds the same type is returned + TypeInner::Struct { ref members, .. } => { + members.get(i).map(|member| member.ty).unwrap_or(ty) + } + // The type isn't indexable, the same type is returned + _ => ty, + } +} + +impl<'source> ParsingContext<'source> { + pub fn parse_external_declaration( + &mut self, + frontend: &mut Frontend, + global_ctx: &mut Context, + ) -> Result<()> { + if self + .parse_declaration(frontend, global_ctx, true, false)? + .is_none() + { + let token = self.bump(frontend)?; + match token.value { + TokenValue::Semicolon if frontend.meta.version == 460 => Ok(()), + _ => { + let expected = match frontend.meta.version { + 460 => vec![TokenValue::Semicolon.into(), ExpectedToken::Eof], + _ => vec![ExpectedToken::Eof], + }; + Err(Error { + kind: ErrorKind::InvalidToken(token.value, expected), + meta: token.meta, + }) + } + } + } else { + Ok(()) + } + } + + pub fn parse_initializer( + &mut self, + frontend: &mut Frontend, + ty: Handle<Type>, + ctx: &mut Context, + ) -> Result<(Handle<Expression>, Span)> { + // initializer: + // assignment_expression + // LEFT_BRACE initializer_list RIGHT_BRACE + // LEFT_BRACE initializer_list COMMA RIGHT_BRACE + // + // initializer_list: + // initializer + // initializer_list COMMA initializer + if let Some(Token { mut meta, .. }) = self.bump_if(frontend, TokenValue::LeftBrace) { + // initializer_list + let mut components = Vec::new(); + loop { + // The type expected to be parsed inside the initializer list + let new_ty = element_or_member_type(ty, components.len(), &mut ctx.module.types); + + components.push(self.parse_initializer(frontend, new_ty, ctx)?.0); + + let token = self.bump(frontend)?; + match token.value { + TokenValue::Comma => { + if let Some(Token { meta: end_meta, .. }) = + self.bump_if(frontend, TokenValue::RightBrace) + { + meta.subsume(end_meta); + break; + } + } + TokenValue::RightBrace => { + meta.subsume(token.meta); + break; + } + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![TokenValue::Comma.into(), TokenValue::RightBrace.into()], + ), + meta: token.meta, + }) + } + } + } + + Ok(( + ctx.add_expression(Expression::Compose { ty, components }, meta)?, + meta, + )) + } else { + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_assignment(frontend, ctx, &mut stmt)?; + let (mut init, init_meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + + let scalar_components = scalar_components(&ctx.module.types[ty].inner); + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut init, init_meta, scalar)?; + } + + Ok((init, init_meta)) + } + } + + // Note: caller preparsed the type and qualifiers + // Note: caller skips this if the fallthrough token is not expected to be consumed here so this + // produced Error::InvalidToken if it isn't consumed + pub fn parse_init_declarator_list( + &mut self, + frontend: &mut Frontend, + mut ty: Handle<Type>, + ctx: &mut DeclarationContext, + ) -> Result<()> { + // init_declarator_list: + // single_declaration + // init_declarator_list COMMA IDENTIFIER + // init_declarator_list COMMA IDENTIFIER array_specifier + // init_declarator_list COMMA IDENTIFIER array_specifier EQUAL initializer + // init_declarator_list COMMA IDENTIFIER EQUAL initializer + // + // single_declaration: + // fully_specified_type + // fully_specified_type IDENTIFIER + // fully_specified_type IDENTIFIER array_specifier + // fully_specified_type IDENTIFIER array_specifier EQUAL initializer + // fully_specified_type IDENTIFIER EQUAL initializer + + // Consume any leading comma, e.g. this is valid: `float, a=1;` + if self + .peek(frontend) + .map_or(false, |t| t.value == TokenValue::Comma) + { + self.next(frontend); + } + + loop { + let token = self.bump(frontend)?; + let name = match token.value { + TokenValue::Semicolon => break, + TokenValue::Identifier(name) => name, + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()], + ), + meta: token.meta, + }) + } + }; + let mut meta = token.meta; + + // array_specifier + // array_specifier EQUAL initializer + // EQUAL initializer + + // parse an array specifier if it exists + // NOTE: unlike other parse methods this one doesn't expect an array specifier and + // returns Ok(None) rather than an error if there is not one + self.parse_array_specifier(frontend, ctx.ctx, &mut meta, &mut ty)?; + + let is_global_const = + ctx.qualifiers.storage.0 == StorageQualifier::Const && ctx.external; + + let init = self + .bump_if(frontend, TokenValue::Assign) + .map::<Result<_>, _>(|_| { + let prev_const = ctx.ctx.is_const; + ctx.ctx.is_const = is_global_const; + + let (mut expr, init_meta) = self.parse_initializer(frontend, ty, ctx.ctx)?; + + let scalar_components = scalar_components(&ctx.ctx.module.types[ty].inner); + if let Some(scalar) = scalar_components { + ctx.ctx.implicit_conversion(&mut expr, init_meta, scalar)?; + } + + ctx.ctx.is_const = prev_const; + + meta.subsume(init_meta); + + Ok(expr) + }) + .transpose()?; + + let decl_initializer; + let late_initializer; + if is_global_const { + decl_initializer = init; + late_initializer = None; + } else if ctx.external { + decl_initializer = + init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok()); + late_initializer = None; + } else if let Some(init) = init { + if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) { + decl_initializer = None; + late_initializer = Some(init); + } else { + decl_initializer = Some(init); + late_initializer = None; + } + } else { + decl_initializer = None; + late_initializer = None; + }; + + let pointer = ctx.add_var(frontend, ty, name, decl_initializer, meta)?; + + if let Some(value) = late_initializer { + ctx.ctx.emit_restart(); + ctx.ctx.body.push(Statement::Store { pointer, value }, meta); + } + + let token = self.bump(frontend)?; + match token.value { + TokenValue::Semicolon => break, + TokenValue::Comma => {} + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![TokenValue::Comma.into(), TokenValue::Semicolon.into()], + ), + meta: token.meta, + }) + } + } + } + + Ok(()) + } + + /// `external` whether or not we are in a global or local context + pub fn parse_declaration( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + external: bool, + is_inside_loop: bool, + ) -> Result<Option<Span>> { + //declaration: + // function_prototype SEMICOLON + // + // init_declarator_list SEMICOLON + // PRECISION precision_qualifier type_specifier SEMICOLON + // + // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE SEMICOLON + // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER SEMICOLON + // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER array_specifier SEMICOLON + // type_qualifier SEMICOLON type_qualifier IDENTIFIER SEMICOLON + // type_qualifier IDENTIFIER identifier_list SEMICOLON + + if self.peek_type_qualifier(frontend) || self.peek_type_name(frontend) { + let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?; + + if self.peek_type_name(frontend) { + // This branch handles variables and function prototypes and if + // external is true also function definitions + let (ty, mut meta) = self.parse_type(frontend, ctx)?; + + let token = self.bump(frontend)?; + let token_fallthrough = match token.value { + TokenValue::Identifier(name) => match self.expect_peek(frontend)?.value { + TokenValue::LeftParen => { + // This branch handles function definition and prototypes + self.bump(frontend)?; + + let result = ty.map(|ty| FunctionResult { ty, binding: None }); + + let mut context = Context::new(frontend, ctx.module, false)?; + + self.parse_function_args(frontend, &mut context)?; + + let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta; + meta.subsume(end_meta); + + let token = self.bump(frontend)?; + return match token.value { + TokenValue::Semicolon => { + // This branch handles function prototypes + frontend.add_prototype(context, name, result, meta); + + Ok(Some(meta)) + } + TokenValue::LeftBrace if external => { + // This branch handles function definitions + // as you can see by the guard this branch + // only happens if external is also true + + // parse the body + self.parse_compound_statement( + token.meta, + frontend, + &mut context, + &mut None, + false, + )?; + + frontend.add_function(context, name, result, meta); + + Ok(Some(meta)) + } + _ if external => Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ + TokenValue::LeftBrace.into(), + TokenValue::Semicolon.into(), + ], + ), + meta: token.meta, + }), + _ => Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![TokenValue::Semicolon.into()], + ), + meta: token.meta, + }), + }; + } + // Pass the token to the init_declarator_list parser + _ => Token { + value: TokenValue::Identifier(name), + meta: token.meta, + }, + }, + // Pass the token to the init_declarator_list parser + _ => token, + }; + + // If program execution has reached here then this will be a + // init_declarator_list + // token_fallthrough will have a token that was already bumped + if let Some(ty) = ty { + let mut ctx = DeclarationContext { + qualifiers, + external, + is_inside_loop, + ctx, + }; + + self.backtrack(token_fallthrough)?; + self.parse_init_declarator_list(frontend, ty, &mut ctx)?; + } else { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError("Declaration cannot have void type".into()), + meta, + }) + } + + Ok(Some(meta)) + } else { + // This branch handles struct definitions and modifiers like + // ```glsl + // layout(early_fragment_tests); + // ``` + let token = self.bump(frontend)?; + match token.value { + TokenValue::Identifier(ty_name) => { + if self.bump_if(frontend, TokenValue::LeftBrace).is_some() { + self.parse_block_declaration( + frontend, + ctx, + &mut qualifiers, + ty_name, + token.meta, + ) + .map(Some) + } else { + if qualifiers.invariant.take().is_some() { + frontend.make_variable_invariant(ctx, &ty_name, token.meta)?; + + qualifiers.unused_errors(&mut frontend.errors); + self.expect(frontend, TokenValue::Semicolon)?; + return Ok(Some(qualifiers.span)); + } + + //TODO: declaration + // type_qualifier IDENTIFIER SEMICOLON + // type_qualifier IDENTIFIER identifier_list SEMICOLON + Err(Error { + kind: ErrorKind::NotImplemented("variable qualifier"), + meta: token.meta, + }) + } + } + TokenValue::Semicolon => { + if let Some(value) = + qualifiers.uint_layout_qualifier("local_size_x", &mut frontend.errors) + { + frontend.meta.workgroup_size[0] = value; + } + if let Some(value) = + qualifiers.uint_layout_qualifier("local_size_y", &mut frontend.errors) + { + frontend.meta.workgroup_size[1] = value; + } + if let Some(value) = + qualifiers.uint_layout_qualifier("local_size_z", &mut frontend.errors) + { + frontend.meta.workgroup_size[2] = value; + } + + frontend.meta.early_fragment_tests |= qualifiers + .none_layout_qualifier("early_fragment_tests", &mut frontend.errors); + + qualifiers.unused_errors(&mut frontend.errors); + + Ok(Some(qualifiers.span)) + } + _ => Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()], + ), + meta: token.meta, + }), + } + } + } else { + match self.peek(frontend).map(|t| &t.value) { + Some(&TokenValue::Precision) => { + // PRECISION precision_qualifier type_specifier SEMICOLON + self.bump(frontend)?; + + let token = self.bump(frontend)?; + let _ = match token.value { + TokenValue::PrecisionQualifier(p) => p, + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ + TokenValue::PrecisionQualifier(Precision::High).into(), + TokenValue::PrecisionQualifier(Precision::Medium).into(), + TokenValue::PrecisionQualifier(Precision::Low).into(), + ], + ), + meta: token.meta, + }) + } + }; + + let (ty, meta) = self.parse_type_non_void(frontend, ctx)?; + + match ctx.module.types[ty].inner { + TypeInner::Scalar(Scalar { + kind: ScalarKind::Float | ScalarKind::Sint, + .. + }) => {} + _ => frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Precision statement can only work on floats and ints".into(), + ), + meta, + }), + } + + self.expect(frontend, TokenValue::Semicolon)?; + + Ok(Some(meta)) + } + _ => Ok(None), + } + } + } + + pub fn parse_block_declaration( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + qualifiers: &mut TypeQualifiers, + ty_name: String, + mut meta: Span, + ) -> Result<Span> { + let layout = match qualifiers.layout_qualifiers.remove(&QualifierKey::Layout) { + Some((QualifierValue::Layout(l), _)) => l, + None => { + if let StorageQualifier::AddressSpace(AddressSpace::Storage { .. }) = + qualifiers.storage.0 + { + StructLayout::Std430 + } else { + StructLayout::Std140 + } + } + _ => unreachable!(), + }; + + let mut members = Vec::new(); + let span = self.parse_struct_declaration_list(frontend, ctx, &mut members, layout)?; + self.expect(frontend, TokenValue::RightBrace)?; + + let mut ty = ctx.module.types.insert( + Type { + name: Some(ty_name), + inner: TypeInner::Struct { + members: members.clone(), + span, + }, + }, + Default::default(), + ); + + let token = self.bump(frontend)?; + let name = match token.value { + TokenValue::Semicolon => None, + TokenValue::Identifier(name) => { + self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?; + + self.expect(frontend, TokenValue::Semicolon)?; + + Some(name) + } + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()], + ), + meta: token.meta, + }) + } + }; + + let global = frontend.add_global_var( + ctx, + VarDeclaration { + qualifiers, + ty, + name, + init: None, + meta, + }, + )?; + + for (i, k, ty) in members.into_iter().enumerate().filter_map(|(i, m)| { + let ty = m.ty; + m.name.map(|s| (i as u32, s, ty)) + }) { + let lookup = GlobalLookup { + kind: match global { + GlobalOrConstant::Global(handle) => GlobalLookupKind::BlockSelect(handle, i), + GlobalOrConstant::Constant(handle) => GlobalLookupKind::Constant(handle, ty), + }, + entry_arg: None, + mutable: true, + }; + ctx.add_global(&k, lookup)?; + + frontend.global_variables.push((k, lookup)); + } + + Ok(meta) + } + + // TODO: Accept layout arguments + pub fn parse_struct_declaration_list( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + members: &mut Vec<StructMember>, + layout: StructLayout, + ) -> Result<u32> { + let mut span = 0; + let mut align = Alignment::ONE; + + loop { + // TODO: type_qualifier + + let (base_ty, mut meta) = self.parse_type_non_void(frontend, ctx)?; + + loop { + let (name, name_meta) = self.expect_ident(frontend)?; + let mut ty = base_ty; + self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?; + + meta.subsume(name_meta); + + let info = offset::calculate_offset( + ty, + meta, + layout, + &mut ctx.module.types, + &mut frontend.errors, + ); + + let member_alignment = info.align; + span = member_alignment.round_up(span); + align = member_alignment.max(align); + + members.push(StructMember { + name: Some(name), + ty: info.ty, + binding: None, + offset: span, + }); + + span += info.span; + + if self.bump_if(frontend, TokenValue::Comma).is_none() { + break; + } + } + + self.expect(frontend, TokenValue::Semicolon)?; + + if let TokenValue::RightBrace = self.expect_peek(frontend)?.value { + break; + } + } + + span = align.round_up(span); + + Ok(span) + } +} diff --git a/third_party/rust/naga/src/front/glsl/parser/expressions.rs b/third_party/rust/naga/src/front/glsl/parser/expressions.rs new file mode 100644 index 0000000000..1b8febce90 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser/expressions.rs @@ -0,0 +1,542 @@ +use std::num::NonZeroU32; + +use crate::{ + front::glsl::{ + ast::{FunctionCall, FunctionCallKind, HirExpr, HirExprKind}, + context::{Context, StmtContext}, + error::{ErrorKind, ExpectedToken}, + parser::ParsingContext, + token::{Token, TokenValue}, + Error, Frontend, Result, Span, + }, + ArraySize, BinaryOperator, Handle, Literal, Type, TypeInner, UnaryOperator, +}; + +impl<'source> ParsingContext<'source> { + pub fn parse_primary( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + let mut token = self.bump(frontend)?; + + let literal = match token.value { + TokenValue::IntConstant(int) => { + if int.width != 32 { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError("Unsupported non-32bit integer".into()), + meta: token.meta, + }); + } + if int.signed { + Literal::I32(int.value as i32) + } else { + Literal::U32(int.value as u32) + } + } + TokenValue::FloatConstant(float) => { + if float.width != 32 { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError("Unsupported floating-point value (expected single-precision floating-point number)".into()), + meta: token.meta, + }); + } + Literal::F32(float.value) + } + TokenValue::BoolConstant(value) => Literal::Bool(value), + TokenValue::LeftParen => { + let expr = self.parse_expression(frontend, ctx, stmt)?; + let meta = self.expect(frontend, TokenValue::RightParen)?.meta; + + token.meta.subsume(meta); + + return Ok(expr); + } + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ + TokenValue::LeftParen.into(), + ExpectedToken::IntLiteral, + ExpectedToken::FloatLiteral, + ExpectedToken::BoolLiteral, + ], + ), + meta: token.meta, + }); + } + }; + + Ok(stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Literal(literal), + meta: token.meta, + }, + Default::default(), + )) + } + + pub fn parse_function_call_args( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + meta: &mut Span, + ) -> Result<Vec<Handle<HirExpr>>> { + let mut args = Vec::new(); + if let Some(token) = self.bump_if(frontend, TokenValue::RightParen) { + meta.subsume(token.meta); + } else { + loop { + args.push(self.parse_assignment(frontend, ctx, stmt)?); + + let token = self.bump(frontend)?; + match token.value { + TokenValue::Comma => {} + TokenValue::RightParen => { + meta.subsume(token.meta); + break; + } + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![TokenValue::Comma.into(), TokenValue::RightParen.into()], + ), + meta: token.meta, + }); + } + } + } + } + + Ok(args) + } + + pub fn parse_postfix( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + let mut base = if self.peek_type_name(frontend) { + let (mut handle, mut meta) = self.parse_type_non_void(frontend, ctx)?; + + self.expect(frontend, TokenValue::LeftParen)?; + let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?; + + if let TypeInner::Array { + size: ArraySize::Dynamic, + stride, + base, + } = ctx.module.types[handle].inner + { + let span = ctx.module.types.get_span(handle); + + let size = u32::try_from(args.len()) + .ok() + .and_then(NonZeroU32::new) + .ok_or(Error { + kind: ErrorKind::SemanticError( + "There must be at least one argument".into(), + ), + meta, + })?; + + handle = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Array { + stride, + base, + size: ArraySize::Constant(size), + }, + }, + span, + ) + } + + stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Call(FunctionCall { + kind: FunctionCallKind::TypeConstructor(handle), + args, + }), + meta, + }, + Default::default(), + ) + } else if let TokenValue::Identifier(_) = self.expect_peek(frontend)?.value { + let (name, mut meta) = self.expect_ident(frontend)?; + + let expr = if self.bump_if(frontend, TokenValue::LeftParen).is_some() { + let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?; + + let kind = match frontend.lookup_type.get(&name) { + Some(ty) => FunctionCallKind::TypeConstructor(*ty), + None => FunctionCallKind::Function(name), + }; + + HirExpr { + kind: HirExprKind::Call(FunctionCall { kind, args }), + meta, + } + } else { + let var = match frontend.lookup_variable(ctx, &name, meta)? { + Some(var) => var, + None => { + return Err(Error { + kind: ErrorKind::UnknownVariable(name), + meta, + }) + } + }; + + HirExpr { + kind: HirExprKind::Variable(var), + meta, + } + }; + + stmt.hir_exprs.append(expr, Default::default()) + } else { + self.parse_primary(frontend, ctx, stmt)? + }; + + while let TokenValue::LeftBracket + | TokenValue::Dot + | TokenValue::Increment + | TokenValue::Decrement = self.expect_peek(frontend)?.value + { + let Token { value, mut meta } = self.bump(frontend)?; + + match value { + TokenValue::LeftBracket => { + let index = self.parse_expression(frontend, ctx, stmt)?; + let end_meta = self.expect(frontend, TokenValue::RightBracket)?.meta; + + meta.subsume(end_meta); + base = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Access { base, index }, + meta, + }, + Default::default(), + ) + } + TokenValue::Dot => { + let (field, end_meta) = self.expect_ident(frontend)?; + + if self.bump_if(frontend, TokenValue::LeftParen).is_some() { + let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?; + + base = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Method { + expr: base, + name: field, + args, + }, + meta, + }, + Default::default(), + ); + continue; + } + + meta.subsume(end_meta); + base = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Select { base, field }, + meta, + }, + Default::default(), + ) + } + TokenValue::Increment | TokenValue::Decrement => { + base = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::PrePostfix { + op: match value { + TokenValue::Increment => crate::BinaryOperator::Add, + _ => crate::BinaryOperator::Subtract, + }, + postfix: true, + expr: base, + }, + meta, + }, + Default::default(), + ) + } + _ => unreachable!(), + } + } + + Ok(base) + } + + pub fn parse_unary( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + Ok(match self.expect_peek(frontend)?.value { + TokenValue::Plus | TokenValue::Dash | TokenValue::Bang | TokenValue::Tilde => { + let Token { value, mut meta } = self.bump(frontend)?; + + let expr = self.parse_unary(frontend, ctx, stmt)?; + let end_meta = stmt.hir_exprs[expr].meta; + + let kind = match value { + TokenValue::Dash => HirExprKind::Unary { + op: UnaryOperator::Negate, + expr, + }, + TokenValue::Bang => HirExprKind::Unary { + op: UnaryOperator::LogicalNot, + expr, + }, + TokenValue::Tilde => HirExprKind::Unary { + op: UnaryOperator::BitwiseNot, + expr, + }, + _ => return Ok(expr), + }; + + meta.subsume(end_meta); + stmt.hir_exprs + .append(HirExpr { kind, meta }, Default::default()) + } + TokenValue::Increment | TokenValue::Decrement => { + let Token { value, meta } = self.bump(frontend)?; + + let expr = self.parse_unary(frontend, ctx, stmt)?; + + stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::PrePostfix { + op: match value { + TokenValue::Increment => crate::BinaryOperator::Add, + _ => crate::BinaryOperator::Subtract, + }, + postfix: false, + expr, + }, + meta, + }, + Default::default(), + ) + } + _ => self.parse_postfix(frontend, ctx, stmt)?, + }) + } + + pub fn parse_binary( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + passthrough: Option<Handle<HirExpr>>, + min_bp: u8, + ) -> Result<Handle<HirExpr>> { + let mut left = passthrough + .ok_or(ErrorKind::EndOfFile /* Dummy error */) + .or_else(|_| self.parse_unary(frontend, ctx, stmt))?; + let mut meta = stmt.hir_exprs[left].meta; + + while let Some((l_bp, r_bp)) = binding_power(&self.expect_peek(frontend)?.value) { + if l_bp < min_bp { + break; + } + + let Token { value, .. } = self.bump(frontend)?; + + let right = self.parse_binary(frontend, ctx, stmt, None, r_bp)?; + let end_meta = stmt.hir_exprs[right].meta; + + meta.subsume(end_meta); + left = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Binary { + left, + op: match value { + TokenValue::LogicalOr => BinaryOperator::LogicalOr, + TokenValue::LogicalXor => BinaryOperator::NotEqual, + TokenValue::LogicalAnd => BinaryOperator::LogicalAnd, + TokenValue::VerticalBar => BinaryOperator::InclusiveOr, + TokenValue::Caret => BinaryOperator::ExclusiveOr, + TokenValue::Ampersand => BinaryOperator::And, + TokenValue::Equal => BinaryOperator::Equal, + TokenValue::NotEqual => BinaryOperator::NotEqual, + TokenValue::GreaterEqual => BinaryOperator::GreaterEqual, + TokenValue::LessEqual => BinaryOperator::LessEqual, + TokenValue::LeftAngle => BinaryOperator::Less, + TokenValue::RightAngle => BinaryOperator::Greater, + TokenValue::LeftShift => BinaryOperator::ShiftLeft, + TokenValue::RightShift => BinaryOperator::ShiftRight, + TokenValue::Plus => BinaryOperator::Add, + TokenValue::Dash => BinaryOperator::Subtract, + TokenValue::Star => BinaryOperator::Multiply, + TokenValue::Slash => BinaryOperator::Divide, + TokenValue::Percent => BinaryOperator::Modulo, + _ => unreachable!(), + }, + right, + }, + meta, + }, + Default::default(), + ) + } + + Ok(left) + } + + pub fn parse_conditional( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + passthrough: Option<Handle<HirExpr>>, + ) -> Result<Handle<HirExpr>> { + let mut condition = self.parse_binary(frontend, ctx, stmt, passthrough, 0)?; + let mut meta = stmt.hir_exprs[condition].meta; + + if self.bump_if(frontend, TokenValue::Question).is_some() { + let accept = self.parse_expression(frontend, ctx, stmt)?; + self.expect(frontend, TokenValue::Colon)?; + let reject = self.parse_assignment(frontend, ctx, stmt)?; + let end_meta = stmt.hir_exprs[reject].meta; + + meta.subsume(end_meta); + condition = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Conditional { + condition, + accept, + reject, + }, + meta, + }, + Default::default(), + ) + } + + Ok(condition) + } + + pub fn parse_assignment( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + let tgt = self.parse_unary(frontend, ctx, stmt)?; + let mut meta = stmt.hir_exprs[tgt].meta; + + Ok(match self.expect_peek(frontend)?.value { + TokenValue::Assign => { + self.bump(frontend)?; + let value = self.parse_assignment(frontend, ctx, stmt)?; + let end_meta = stmt.hir_exprs[value].meta; + + meta.subsume(end_meta); + stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Assign { tgt, value }, + meta, + }, + Default::default(), + ) + } + TokenValue::OrAssign + | TokenValue::AndAssign + | TokenValue::AddAssign + | TokenValue::DivAssign + | TokenValue::ModAssign + | TokenValue::SubAssign + | TokenValue::MulAssign + | TokenValue::LeftShiftAssign + | TokenValue::RightShiftAssign + | TokenValue::XorAssign => { + let token = self.bump(frontend)?; + let right = self.parse_assignment(frontend, ctx, stmt)?; + let end_meta = stmt.hir_exprs[right].meta; + + meta.subsume(end_meta); + let value = stmt.hir_exprs.append( + HirExpr { + meta, + kind: HirExprKind::Binary { + left: tgt, + op: match token.value { + TokenValue::OrAssign => BinaryOperator::InclusiveOr, + TokenValue::AndAssign => BinaryOperator::And, + TokenValue::AddAssign => BinaryOperator::Add, + TokenValue::DivAssign => BinaryOperator::Divide, + TokenValue::ModAssign => BinaryOperator::Modulo, + TokenValue::SubAssign => BinaryOperator::Subtract, + TokenValue::MulAssign => BinaryOperator::Multiply, + TokenValue::LeftShiftAssign => BinaryOperator::ShiftLeft, + TokenValue::RightShiftAssign => BinaryOperator::ShiftRight, + TokenValue::XorAssign => BinaryOperator::ExclusiveOr, + _ => unreachable!(), + }, + right, + }, + }, + Default::default(), + ); + + stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Assign { tgt, value }, + meta, + }, + Default::default(), + ) + } + _ => self.parse_conditional(frontend, ctx, stmt, Some(tgt))?, + }) + } + + pub fn parse_expression( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + let mut expr = self.parse_assignment(frontend, ctx, stmt)?; + + while let TokenValue::Comma = self.expect_peek(frontend)?.value { + self.bump(frontend)?; + expr = self.parse_assignment(frontend, ctx, stmt)?; + } + + Ok(expr) + } +} + +const fn binding_power(value: &TokenValue) -> Option<(u8, u8)> { + Some(match *value { + TokenValue::LogicalOr => (1, 2), + TokenValue::LogicalXor => (3, 4), + TokenValue::LogicalAnd => (5, 6), + TokenValue::VerticalBar => (7, 8), + TokenValue::Caret => (9, 10), + TokenValue::Ampersand => (11, 12), + TokenValue::Equal | TokenValue::NotEqual => (13, 14), + TokenValue::GreaterEqual + | TokenValue::LessEqual + | TokenValue::LeftAngle + | TokenValue::RightAngle => (15, 16), + TokenValue::LeftShift | TokenValue::RightShift => (17, 18), + TokenValue::Plus | TokenValue::Dash => (19, 20), + TokenValue::Star | TokenValue::Slash | TokenValue::Percent => (21, 22), + _ => return None, + }) +} diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs new file mode 100644 index 0000000000..38184eedf7 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs @@ -0,0 +1,656 @@ +use crate::front::glsl::context::ExprPos; +use crate::front::glsl::Span; +use crate::Literal; +use crate::{ + front::glsl::{ + ast::ParameterQualifier, + context::Context, + parser::ParsingContext, + token::{Token, TokenValue}, + variables::VarDeclaration, + Error, ErrorKind, Frontend, Result, + }, + Block, Expression, Statement, SwitchCase, UnaryOperator, +}; + +impl<'source> ParsingContext<'source> { + pub fn peek_parameter_qualifier(&mut self, frontend: &mut Frontend) -> bool { + self.peek(frontend).map_or(false, |t| match t.value { + TokenValue::In | TokenValue::Out | TokenValue::InOut | TokenValue::Const => true, + _ => false, + }) + } + + /// Returns the parsed `ParameterQualifier` or `ParameterQualifier::In` + pub fn parse_parameter_qualifier(&mut self, frontend: &mut Frontend) -> ParameterQualifier { + if self.peek_parameter_qualifier(frontend) { + match self.bump(frontend).unwrap().value { + TokenValue::In => ParameterQualifier::In, + TokenValue::Out => ParameterQualifier::Out, + TokenValue::InOut => ParameterQualifier::InOut, + TokenValue::Const => ParameterQualifier::Const, + _ => unreachable!(), + } + } else { + ParameterQualifier::In + } + } + + pub fn parse_statement( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + terminator: &mut Option<usize>, + is_inside_loop: bool, + ) -> Result<Option<Span>> { + // Type qualifiers always identify a declaration statement + if self.peek_type_qualifier(frontend) { + return self.parse_declaration(frontend, ctx, false, is_inside_loop); + } + + // Type names can identify either declaration statements or type constructors + // depending on whether the token following the type name is a `(` (LeftParen) + if self.peek_type_name(frontend) { + // Start by consuming the type name so that we can peek the token after it + let token = self.bump(frontend)?; + // Peek the next token and check if it's a `(` (LeftParen) if so the statement + // is a constructor, otherwise it's a declaration. We need to do the check + // beforehand and not in the if since we will backtrack before the if + let declaration = TokenValue::LeftParen != self.expect_peek(frontend)?.value; + + self.backtrack(token)?; + + if declaration { + return self.parse_declaration(frontend, ctx, false, is_inside_loop); + } + } + + let new_break = || { + let mut block = Block::new(); + block.push(Statement::Break, crate::Span::default()); + block + }; + + let &Token { + ref value, + mut meta, + } = self.expect_peek(frontend)?; + + let meta_rest = match *value { + TokenValue::Continue => { + let meta = self.bump(frontend)?.meta; + ctx.body.push(Statement::Continue, meta); + terminator.get_or_insert(ctx.body.len()); + self.expect(frontend, TokenValue::Semicolon)?.meta + } + TokenValue::Break => { + let meta = self.bump(frontend)?.meta; + ctx.body.push(Statement::Break, meta); + terminator.get_or_insert(ctx.body.len()); + self.expect(frontend, TokenValue::Semicolon)?.meta + } + TokenValue::Return => { + self.bump(frontend)?; + let (value, meta) = match self.expect_peek(frontend)?.value { + TokenValue::Semicolon => (None, self.bump(frontend)?.meta), + _ => { + // TODO: Implicit conversions + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + self.expect(frontend, TokenValue::Semicolon)?; + let (handle, meta) = + ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + (Some(handle), meta) + } + }; + + ctx.emit_restart(); + + ctx.body.push(Statement::Return { value }, meta); + terminator.get_or_insert(ctx.body.len()); + + meta + } + TokenValue::Discard => { + let meta = self.bump(frontend)?.meta; + ctx.body.push(Statement::Kill, meta); + terminator.get_or_insert(ctx.body.len()); + + self.expect(frontend, TokenValue::Semicolon)?.meta + } + TokenValue::If => { + let mut meta = self.bump(frontend)?.meta; + + self.expect(frontend, TokenValue::LeftParen)?; + let condition = { + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + let (handle, more_meta) = + ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + meta.subsume(more_meta); + handle + }; + self.expect(frontend, TokenValue::RightParen)?; + + let accept = ctx.new_body(|ctx| { + if let Some(more_meta) = + self.parse_statement(frontend, ctx, &mut None, is_inside_loop)? + { + meta.subsume(more_meta); + } + Ok(()) + })?; + + let reject = ctx.new_body(|ctx| { + if self.bump_if(frontend, TokenValue::Else).is_some() { + if let Some(more_meta) = + self.parse_statement(frontend, ctx, &mut None, is_inside_loop)? + { + meta.subsume(more_meta); + } + } + Ok(()) + })?; + + ctx.body.push( + Statement::If { + condition, + accept, + reject, + }, + meta, + ); + + meta + } + TokenValue::Switch => { + let mut meta = self.bump(frontend)?.meta; + let end_meta; + + self.expect(frontend, TokenValue::LeftParen)?; + + let (selector, uint) = { + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + let (root, meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + let uint = ctx.resolve_type(root, meta)?.scalar_kind() + == Some(crate::ScalarKind::Uint); + (root, uint) + }; + + self.expect(frontend, TokenValue::RightParen)?; + + ctx.emit_restart(); + + let mut cases = Vec::new(); + // Track if any default case is present in the switch statement. + let mut default_present = false; + + self.expect(frontend, TokenValue::LeftBrace)?; + loop { + let value = match self.expect_peek(frontend)?.value { + TokenValue::Case => { + self.bump(frontend)?; + + let (const_expr, meta) = + self.parse_constant_expression(frontend, ctx.module)?; + + match ctx.module.const_expressions[const_expr] { + Expression::Literal(Literal::I32(value)) => match uint { + // This unchecked cast isn't good, but since + // we only reach this code when the selector + // is unsigned but the case label is signed, + // verification will reject the module + // anyway (which also matches GLSL's rules). + true => crate::SwitchValue::U32(value as u32), + false => crate::SwitchValue::I32(value), + }, + Expression::Literal(Literal::U32(value)) => { + crate::SwitchValue::U32(value) + } + _ => { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Case values can only be integers".into(), + ), + meta, + }); + + crate::SwitchValue::I32(0) + } + } + } + TokenValue::Default => { + self.bump(frontend)?; + default_present = true; + crate::SwitchValue::Default + } + TokenValue::RightBrace => { + end_meta = self.bump(frontend)?.meta; + break; + } + _ => { + let Token { value, meta } = self.bump(frontend)?; + return Err(Error { + kind: ErrorKind::InvalidToken( + value, + vec![ + TokenValue::Case.into(), + TokenValue::Default.into(), + TokenValue::RightBrace.into(), + ], + ), + meta, + }); + } + }; + + self.expect(frontend, TokenValue::Colon)?; + + let mut fall_through = true; + + let body = ctx.new_body(|ctx| { + let mut case_terminator = None; + loop { + match self.expect_peek(frontend)?.value { + TokenValue::Case | TokenValue::Default | TokenValue::RightBrace => { + break + } + _ => { + self.parse_statement( + frontend, + ctx, + &mut case_terminator, + is_inside_loop, + )?; + } + } + } + + if let Some(mut idx) = case_terminator { + if let Statement::Break = ctx.body[idx - 1] { + fall_through = false; + idx -= 1; + } + + ctx.body.cull(idx..) + } + + Ok(()) + })?; + + cases.push(SwitchCase { + value, + body, + fall_through, + }) + } + + meta.subsume(end_meta); + + // NOTE: do not unwrap here since a switch statement isn't required + // to have any cases. + if let Some(case) = cases.last_mut() { + // GLSL requires that the last case not be empty, so we check + // that here and produce an error otherwise (fall_through must + // also be checked because `break`s count as statements but + // they aren't added to the body) + if case.body.is_empty() && case.fall_through { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "last case/default label must be followed by statements".into(), + ), + meta, + }) + } + + // GLSL allows the last case to not have any `break` statement, + // this would mark it as fall through but naga's IR requires that + // the last case must not be fall through, so we mark need to mark + // the last case as not fall through always. + case.fall_through = false; + } + + // Add an empty default case in case non was present, this is needed because + // naga's IR requires that all switch statements must have a default case but + // GLSL doesn't require that, so we might need to add an empty default case. + if !default_present { + cases.push(SwitchCase { + value: crate::SwitchValue::Default, + body: Block::new(), + fall_through: false, + }) + } + + ctx.body.push(Statement::Switch { selector, cases }, meta); + + meta + } + TokenValue::While => { + let mut meta = self.bump(frontend)?.meta; + + let loop_body = ctx.new_body(|ctx| { + let mut stmt = ctx.stmt_ctx(); + self.expect(frontend, TokenValue::LeftParen)?; + let root = self.parse_expression(frontend, ctx, &mut stmt)?; + meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta); + + let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?; + let condition = ctx.add_expression( + Expression::Unary { + op: UnaryOperator::LogicalNot, + expr, + }, + expr_meta, + )?; + + ctx.emit_restart(); + + ctx.body.push( + Statement::If { + condition, + accept: new_break(), + reject: Block::new(), + }, + crate::Span::default(), + ); + + meta.subsume(expr_meta); + + if let Some(body_meta) = self.parse_statement(frontend, ctx, &mut None, true)? { + meta.subsume(body_meta); + } + Ok(()) + })?; + + ctx.body.push( + Statement::Loop { + body: loop_body, + continuing: Block::new(), + break_if: None, + }, + meta, + ); + + meta + } + TokenValue::Do => { + let mut meta = self.bump(frontend)?.meta; + + let loop_body = ctx.new_body(|ctx| { + let mut terminator = None; + self.parse_statement(frontend, ctx, &mut terminator, true)?; + + let mut stmt = ctx.stmt_ctx(); + + self.expect(frontend, TokenValue::While)?; + self.expect(frontend, TokenValue::LeftParen)?; + let root = self.parse_expression(frontend, ctx, &mut stmt)?; + let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta; + + meta.subsume(end_meta); + + let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?; + let condition = ctx.add_expression( + Expression::Unary { + op: UnaryOperator::LogicalNot, + expr, + }, + expr_meta, + )?; + + ctx.emit_restart(); + + ctx.body.push( + Statement::If { + condition, + accept: new_break(), + reject: Block::new(), + }, + crate::Span::default(), + ); + + if let Some(idx) = terminator { + ctx.body.cull(idx..) + } + Ok(()) + })?; + + ctx.body.push( + Statement::Loop { + body: loop_body, + continuing: Block::new(), + break_if: None, + }, + meta, + ); + + meta + } + TokenValue::For => { + let mut meta = self.bump(frontend)?.meta; + + ctx.symbol_table.push_scope(); + self.expect(frontend, TokenValue::LeftParen)?; + + if self.bump_if(frontend, TokenValue::Semicolon).is_none() { + if self.peek_type_name(frontend) || self.peek_type_qualifier(frontend) { + self.parse_declaration(frontend, ctx, false, false)?; + } else { + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?; + self.expect(frontend, TokenValue::Semicolon)?; + } + } + + let loop_body = ctx.new_body(|ctx| { + if self.bump_if(frontend, TokenValue::Semicolon).is_none() { + let (expr, expr_meta) = if self.peek_type_name(frontend) + || self.peek_type_qualifier(frontend) + { + let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?; + let (ty, mut meta) = self.parse_type_non_void(frontend, ctx)?; + let name = self.expect_ident(frontend)?.0; + + self.expect(frontend, TokenValue::Assign)?; + + let (value, end_meta) = self.parse_initializer(frontend, ty, ctx)?; + meta.subsume(end_meta); + + let decl = VarDeclaration { + qualifiers: &mut qualifiers, + ty, + name: Some(name), + init: None, + meta, + }; + + let pointer = frontend.add_local_var(ctx, decl)?; + + ctx.emit_restart(); + + ctx.body.push(Statement::Store { pointer, value }, meta); + + (value, end_meta) + } else { + let mut stmt = ctx.stmt_ctx(); + let root = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)? + }; + + let condition = ctx.add_expression( + Expression::Unary { + op: UnaryOperator::LogicalNot, + expr, + }, + expr_meta, + )?; + + ctx.emit_restart(); + + ctx.body.push( + Statement::If { + condition, + accept: new_break(), + reject: Block::new(), + }, + crate::Span::default(), + ); + + self.expect(frontend, TokenValue::Semicolon)?; + } + Ok(()) + })?; + + let continuing = ctx.new_body(|ctx| { + match self.expect_peek(frontend)?.value { + TokenValue::RightParen => {} + _ => { + let mut stmt = ctx.stmt_ctx(); + let rest = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower(stmt, frontend, rest, ExprPos::Rhs)?; + } + } + Ok(()) + })?; + + meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta); + + let loop_body = ctx.with_body(loop_body, |ctx| { + if let Some(stmt_meta) = self.parse_statement(frontend, ctx, &mut None, true)? { + meta.subsume(stmt_meta); + } + Ok(()) + })?; + + ctx.body.push( + Statement::Loop { + body: loop_body, + continuing, + break_if: None, + }, + meta, + ); + + ctx.symbol_table.pop_scope(); + + meta + } + TokenValue::LeftBrace => { + let mut meta = self.bump(frontend)?.meta; + + let mut block_terminator = None; + + let block = ctx.new_body(|ctx| { + let block_meta = self.parse_compound_statement( + meta, + frontend, + ctx, + &mut block_terminator, + is_inside_loop, + )?; + meta.subsume(block_meta); + Ok(()) + })?; + + ctx.body.push(Statement::Block(block), meta); + if block_terminator.is_some() { + terminator.get_or_insert(ctx.body.len()); + } + + meta + } + TokenValue::Semicolon => self.bump(frontend)?.meta, + _ => { + // Attempt to force expression parsing for remainder of the + // tokens. Unknown or invalid tokens will be caught there and + // turned into an error. + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?; + self.expect(frontend, TokenValue::Semicolon)?.meta + } + }; + + meta.subsume(meta_rest); + Ok(Some(meta)) + } + + pub fn parse_compound_statement( + &mut self, + mut meta: Span, + frontend: &mut Frontend, + ctx: &mut Context, + terminator: &mut Option<usize>, + is_inside_loop: bool, + ) -> Result<Span> { + ctx.symbol_table.push_scope(); + + loop { + if let Some(Token { + meta: brace_meta, .. + }) = self.bump_if(frontend, TokenValue::RightBrace) + { + meta.subsume(brace_meta); + break; + } + + let stmt = self.parse_statement(frontend, ctx, terminator, is_inside_loop)?; + + if let Some(stmt_meta) = stmt { + meta.subsume(stmt_meta); + } + } + + if let Some(idx) = *terminator { + ctx.body.cull(idx..) + } + + ctx.symbol_table.pop_scope(); + + Ok(meta) + } + + pub fn parse_function_args( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<()> { + if self.bump_if(frontend, TokenValue::Void).is_some() { + return Ok(()); + } + + loop { + if self.peek_type_name(frontend) || self.peek_parameter_qualifier(frontend) { + let qualifier = self.parse_parameter_qualifier(frontend); + let mut ty = self.parse_type_non_void(frontend, ctx)?.0; + + match self.expect_peek(frontend)?.value { + TokenValue::Comma => { + self.bump(frontend)?; + ctx.add_function_arg(None, ty, qualifier)?; + continue; + } + TokenValue::Identifier(_) => { + let mut name = self.expect_ident(frontend)?; + self.parse_array_specifier(frontend, ctx, &mut name.1, &mut ty)?; + + ctx.add_function_arg(Some(name), ty, qualifier)?; + + if self.bump_if(frontend, TokenValue::Comma).is_some() { + continue; + } + + break; + } + _ => break, + } + } + + break; + } + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/front/glsl/parser/types.rs b/third_party/rust/naga/src/front/glsl/parser/types.rs new file mode 100644 index 0000000000..1b612b298d --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser/types.rs @@ -0,0 +1,443 @@ +use std::num::NonZeroU32; + +use crate::{ + front::glsl::{ + ast::{QualifierKey, QualifierValue, StorageQualifier, StructLayout, TypeQualifiers}, + context::Context, + error::ExpectedToken, + parser::ParsingContext, + token::{Token, TokenValue}, + Error, ErrorKind, Frontend, Result, + }, + AddressSpace, ArraySize, Handle, Span, Type, TypeInner, +}; + +impl<'source> ParsingContext<'source> { + /// Parses an optional array_specifier returning whether or not it's present + /// and modifying the type handle if it exists + pub fn parse_array_specifier( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + span: &mut Span, + ty: &mut Handle<Type>, + ) -> Result<()> { + while self.parse_array_specifier_single(frontend, ctx, span, ty)? {} + Ok(()) + } + + /// Implementation of [`Self::parse_array_specifier`] for a single array_specifier + fn parse_array_specifier_single( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + span: &mut Span, + ty: &mut Handle<Type>, + ) -> Result<bool> { + if self.bump_if(frontend, TokenValue::LeftBracket).is_some() { + let size = if let Some(Token { meta, .. }) = + self.bump_if(frontend, TokenValue::RightBracket) + { + span.subsume(meta); + ArraySize::Dynamic + } else { + let (value, constant_span) = self.parse_uint_constant(frontend, ctx)?; + let size = NonZeroU32::new(value).ok_or(Error { + kind: ErrorKind::SemanticError("Array size must be greater than zero".into()), + meta: constant_span, + })?; + let end_span = self.expect(frontend, TokenValue::RightBracket)?.meta; + span.subsume(end_span); + ArraySize::Constant(size) + }; + + frontend.layouter.update(ctx.module.to_ctx()).unwrap(); + let stride = frontend.layouter[*ty].to_stride(); + *ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Array { + base: *ty, + size, + stride, + }, + }, + *span, + ); + + Ok(true) + } else { + Ok(false) + } + } + + pub fn parse_type( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<(Option<Handle<Type>>, Span)> { + let token = self.bump(frontend)?; + let mut handle = match token.value { + TokenValue::Void => return Ok((None, token.meta)), + TokenValue::TypeName(ty) => ctx.module.types.insert(ty, token.meta), + TokenValue::Struct => { + let mut meta = token.meta; + let ty_name = self.expect_ident(frontend)?.0; + self.expect(frontend, TokenValue::LeftBrace)?; + let mut members = Vec::new(); + let span = self.parse_struct_declaration_list( + frontend, + ctx, + &mut members, + StructLayout::Std140, + )?; + let end_meta = self.expect(frontend, TokenValue::RightBrace)?.meta; + meta.subsume(end_meta); + let ty = ctx.module.types.insert( + Type { + name: Some(ty_name.clone()), + inner: TypeInner::Struct { members, span }, + }, + meta, + ); + frontend.lookup_type.insert(ty_name, ty); + ty + } + TokenValue::Identifier(ident) => match frontend.lookup_type.get(&ident) { + Some(ty) => *ty, + None => { + return Err(Error { + kind: ErrorKind::UnknownType(ident), + meta: token.meta, + }) + } + }, + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ + TokenValue::Void.into(), + TokenValue::Struct.into(), + ExpectedToken::TypeName, + ], + ), + meta: token.meta, + }); + } + }; + + let mut span = token.meta; + self.parse_array_specifier(frontend, ctx, &mut span, &mut handle)?; + Ok((Some(handle), span)) + } + + pub fn parse_type_non_void( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<(Handle<Type>, Span)> { + let (maybe_ty, meta) = self.parse_type(frontend, ctx)?; + let ty = maybe_ty.ok_or_else(|| Error { + kind: ErrorKind::SemanticError("Type can't be void".into()), + meta, + })?; + + Ok((ty, meta)) + } + + pub fn peek_type_qualifier(&mut self, frontend: &mut Frontend) -> bool { + self.peek(frontend).map_or(false, |t| match t.value { + TokenValue::Invariant + | TokenValue::Interpolation(_) + | TokenValue::Sampling(_) + | TokenValue::PrecisionQualifier(_) + | TokenValue::Const + | TokenValue::In + | TokenValue::Out + | TokenValue::Uniform + | TokenValue::Shared + | TokenValue::Buffer + | TokenValue::Restrict + | TokenValue::MemoryQualifier(_) + | TokenValue::Layout => true, + _ => false, + }) + } + + pub fn parse_type_qualifiers<'a>( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<TypeQualifiers<'a>> { + let mut qualifiers = TypeQualifiers::default(); + + while self.peek_type_qualifier(frontend) { + let token = self.bump(frontend)?; + + // Handle layout qualifiers outside the match since this can push multiple values + if token.value == TokenValue::Layout { + self.parse_layout_qualifier_id_list(frontend, ctx, &mut qualifiers)?; + continue; + } + + qualifiers.span.subsume(token.meta); + + match token.value { + TokenValue::Invariant => { + if qualifiers.invariant.is_some() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one invariant qualifier per declaration" + .into(), + ), + meta: token.meta, + }) + } + + qualifiers.invariant = Some(token.meta); + } + TokenValue::Interpolation(i) => { + if qualifiers.interpolation.is_some() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one interpolation qualifier per declaration" + .into(), + ), + meta: token.meta, + }) + } + + qualifiers.interpolation = Some((i, token.meta)); + } + TokenValue::Const + | TokenValue::In + | TokenValue::Out + | TokenValue::Uniform + | TokenValue::Shared + | TokenValue::Buffer => { + let storage = match token.value { + TokenValue::Const => StorageQualifier::Const, + TokenValue::In => StorageQualifier::Input, + TokenValue::Out => StorageQualifier::Output, + TokenValue::Uniform => { + StorageQualifier::AddressSpace(AddressSpace::Uniform) + } + TokenValue::Shared => { + StorageQualifier::AddressSpace(AddressSpace::WorkGroup) + } + TokenValue::Buffer => { + StorageQualifier::AddressSpace(AddressSpace::Storage { + access: crate::StorageAccess::all(), + }) + } + _ => unreachable!(), + }; + + if StorageQualifier::AddressSpace(AddressSpace::Function) + != qualifiers.storage.0 + { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one storage qualifier per declaration".into(), + ), + meta: token.meta, + }); + } + + qualifiers.storage = (storage, token.meta); + } + TokenValue::Sampling(s) => { + if qualifiers.sampling.is_some() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one sampling qualifier per declaration" + .into(), + ), + meta: token.meta, + }) + } + + qualifiers.sampling = Some((s, token.meta)); + } + TokenValue::PrecisionQualifier(p) => { + if qualifiers.precision.is_some() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one precision qualifier per declaration" + .into(), + ), + meta: token.meta, + }) + } + + qualifiers.precision = Some((p, token.meta)); + } + TokenValue::MemoryQualifier(access) => { + let storage_access = qualifiers + .storage_access + .get_or_insert((crate::StorageAccess::all(), Span::default())); + if !storage_access.0.contains(!access) { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "The same memory qualifier can only be used once".into(), + ), + meta: token.meta, + }) + } + + storage_access.0 &= access; + storage_access.1.subsume(token.meta); + } + TokenValue::Restrict => continue, + _ => unreachable!(), + }; + } + + Ok(qualifiers) + } + + pub fn parse_layout_qualifier_id_list( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + qualifiers: &mut TypeQualifiers, + ) -> Result<()> { + self.expect(frontend, TokenValue::LeftParen)?; + loop { + self.parse_layout_qualifier_id(frontend, ctx, &mut qualifiers.layout_qualifiers)?; + + if self.bump_if(frontend, TokenValue::Comma).is_some() { + continue; + } + + break; + } + let token = self.expect(frontend, TokenValue::RightParen)?; + qualifiers.span.subsume(token.meta); + + Ok(()) + } + + pub fn parse_layout_qualifier_id( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + qualifiers: &mut crate::FastHashMap<QualifierKey, (QualifierValue, Span)>, + ) -> Result<()> { + // layout_qualifier_id: + // IDENTIFIER + // IDENTIFIER EQUAL constant_expression + // SHARED + let mut token = self.bump(frontend)?; + match token.value { + TokenValue::Identifier(name) => { + let (key, value) = match name.as_str() { + "std140" => ( + QualifierKey::Layout, + QualifierValue::Layout(StructLayout::Std140), + ), + "std430" => ( + QualifierKey::Layout, + QualifierValue::Layout(StructLayout::Std430), + ), + word => { + if let Some(format) = map_image_format(word) { + (QualifierKey::Format, QualifierValue::Format(format)) + } else { + let key = QualifierKey::String(name.into()); + let value = if self.bump_if(frontend, TokenValue::Assign).is_some() { + let (value, end_meta) = + match self.parse_uint_constant(frontend, ctx) { + Ok(v) => v, + Err(e) => { + frontend.errors.push(e); + (0, Span::default()) + } + }; + token.meta.subsume(end_meta); + + QualifierValue::Uint(value) + } else { + QualifierValue::None + }; + + (key, value) + } + } + }; + + qualifiers.insert(key, (value, token.meta)); + } + _ => frontend.errors.push(Error { + kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]), + meta: token.meta, + }), + } + + Ok(()) + } + + pub fn peek_type_name(&mut self, frontend: &mut Frontend) -> bool { + self.peek(frontend).map_or(false, |t| match t.value { + TokenValue::TypeName(_) | TokenValue::Void => true, + TokenValue::Struct => true, + TokenValue::Identifier(ref ident) => frontend.lookup_type.contains_key(ident), + _ => false, + }) + } +} + +fn map_image_format(word: &str) -> Option<crate::StorageFormat> { + use crate::StorageFormat as Sf; + + let format = match word { + // float-image-format-qualifier: + "rgba32f" => Sf::Rgba32Float, + "rgba16f" => Sf::Rgba16Float, + "rg32f" => Sf::Rg32Float, + "rg16f" => Sf::Rg16Float, + "r11f_g11f_b10f" => Sf::Rg11b10Float, + "r32f" => Sf::R32Float, + "r16f" => Sf::R16Float, + "rgba16" => Sf::Rgba16Unorm, + "rgb10_a2ui" => Sf::Rgb10a2Uint, + "rgb10_a2" => Sf::Rgb10a2Unorm, + "rgba8" => Sf::Rgba8Unorm, + "rg16" => Sf::Rg16Unorm, + "rg8" => Sf::Rg8Unorm, + "r16" => Sf::R16Unorm, + "r8" => Sf::R8Unorm, + "rgba16_snorm" => Sf::Rgba16Snorm, + "rgba8_snorm" => Sf::Rgba8Snorm, + "rg16_snorm" => Sf::Rg16Snorm, + "rg8_snorm" => Sf::Rg8Snorm, + "r16_snorm" => Sf::R16Snorm, + "r8_snorm" => Sf::R8Snorm, + // int-image-format-qualifier: + "rgba32i" => Sf::Rgba32Sint, + "rgba16i" => Sf::Rgba16Sint, + "rgba8i" => Sf::Rgba8Sint, + "rg32i" => Sf::Rg32Sint, + "rg16i" => Sf::Rg16Sint, + "rg8i" => Sf::Rg8Sint, + "r32i" => Sf::R32Sint, + "r16i" => Sf::R16Sint, + "r8i" => Sf::R8Sint, + // uint-image-format-qualifier: + "rgba32ui" => Sf::Rgba32Uint, + "rgba16ui" => Sf::Rgba16Uint, + "rgba8ui" => Sf::Rgba8Uint, + "rg32ui" => Sf::Rg32Uint, + "rg16ui" => Sf::Rg16Uint, + "rg8ui" => Sf::Rg8Uint, + "r32ui" => Sf::R32Uint, + "r16ui" => Sf::R16Uint, + "r8ui" => Sf::R8Uint, + // TODO: These next ones seem incorrect to me + // "rgb10_a2ui" => Sf::Rgb10a2Unorm, + _ => return None, + }; + + Some(format) +} diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs new file mode 100644 index 0000000000..259052cd27 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs @@ -0,0 +1,858 @@ +use super::{ + ast::Profile, + error::ExpectedToken, + error::{Error, ErrorKind, ParseError}, + token::TokenValue, + Frontend, Options, Span, +}; +use crate::ShaderStage; +use pp_rs::token::PreprocessorError; + +#[test] +fn version() { + let mut frontend = Frontend::default(); + + // invalid versions + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 99000\n void main(){}", + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::InvalidVersion(99000), + meta: Span::new(9, 14) + }], + }, + ); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 449\n void main(){}", + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::InvalidVersion(449), + meta: Span::new(9, 12) + }] + }, + ); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 450 smart\n void main(){}", + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::InvalidProfile("smart".into()), + meta: Span::new(13, 18), + }] + }, + ); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 450\nvoid main(){} #version 450", + ) + .err() + .unwrap(), + ParseError { + errors: vec![ + Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,), + meta: Span::new(27, 28), + }, + Error { + kind: ErrorKind::InvalidToken( + TokenValue::Identifier("version".into()), + vec![ExpectedToken::Eof] + ), + meta: Span::new(28, 35) + } + ] + }, + ); + + // valid versions + frontend + .parse( + &Options::from(ShaderStage::Vertex), + " # version 450\nvoid main() {}", + ) + .unwrap(); + assert_eq!( + (frontend.metadata().version, frontend.metadata().profile), + (450, Profile::Core) + ); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 450\nvoid main() {}", + ) + .unwrap(); + assert_eq!( + (frontend.metadata().version, frontend.metadata().profile), + (450, Profile::Core) + ); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 450 core\nvoid main(void) {}", + ) + .unwrap(); + assert_eq!( + (frontend.metadata().version, frontend.metadata().profile), + (450, Profile::Core) + ); +} + +#[test] +fn control_flow() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + if (true) { + return 1; + } else { + return 2; + } + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + if (true) { + return 1; + } + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + int x; + int y = 3; + switch (5) { + case 2: + x = 2; + case 5: + x = 5; + y = 2; + break; + default: + x = 0; + } + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + int x = 0; + while(x < 5) { + x = x + 1; + } + do { + x = x - 1; + } while(x >= 4) + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + int x = 0; + for(int i = 0; i < 10;) { + x = x + 2; + } + for(;;); + return x; + } + "#, + ) + .unwrap(); +} + +#[test] +fn declarations() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(location = 0) in vec2 v_uv; + layout(location = 0) out vec4 o_color; + layout(set = 1, binding = 1) uniform texture2D tex; + layout(set = 1, binding = 2) uniform sampler tex_sampler; + + layout(early_fragment_tests) in; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(std140, set = 2, binding = 0) + uniform u_locals { + vec3 model_offs; + float load_time; + ivec4 atlas_offs; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(push_constant) + uniform u_locals { + vec3 model_offs; + float load_time; + ivec4 atlas_offs; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(std430, set = 2, binding = 0) + uniform u_locals { + vec3 model_offs; + float load_time; + ivec4 atlas_offs; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(std140, set = 2, binding = 0) + uniform u_locals { + vec3 model_offs; + float load_time; + } block_var; + + void main() { + load_time * model_offs; + block_var.load_time * block_var.model_offs; + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + float vector = vec4(1.0 / 17.0, 9.0 / 17.0, 3.0 / 17.0, 11.0 / 17.0); + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + precision highp float; + + void main() {} + "#, + ) + .unwrap(); +} + +#[test] +fn textures() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(location = 0) in vec2 v_uv; + layout(location = 0) out vec4 o_color; + layout(set = 1, binding = 1) uniform texture2D tex; + layout(set = 1, binding = 2) uniform sampler tex_sampler; + void main() { + o_color = texture(sampler2D(tex, tex_sampler), v_uv); + o_color.a = texture(sampler2D(tex, tex_sampler), v_uv, 2.0).a; + } + "#, + ) + .unwrap(); +} + +#[test] +fn functions() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void test1(float); + void test1(float) {} + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void test2(float a) {} + void test3(float a, float b) {} + void test4(float, float) {} + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float test(float a) { return a; } + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float test(vec4 p) { + return p.x; + } + + void main() {} + "#, + ) + .unwrap(); + + // Function overloading + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float test(vec2 p); + float test(vec3 p); + float test(vec4 p); + + float test(vec2 p) { + return p.x; + } + + float test(vec3 p) { + return p.x; + } + + float test(vec4 p) { + return p.x; + } + + void main() {} + "#, + ) + .unwrap(); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + int test(vec4 p) { + return p.x; + } + + float test(vec4 p) { + return p.x; + } + + void main() {} + "#, + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::SemanticError("Function already defined".into()), + meta: Span::new(134, 152), + }] + }, + ); + + println!(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float callee(uint q) { + return float(q); + } + + float caller() { + callee(1u); + } + + void main() {} + "#, + ) + .unwrap(); + + // Nested function call + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + layout(set = 0, binding = 1) uniform texture2D t_noise; + layout(set = 0, binding = 2) uniform sampler s_noise; + + void main() { + textureLod(sampler2D(t_noise, s_noise), vec2(1.0), 0); + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void fun(vec2 in_parameter, out float out_parameter) { + ivec2 _ = ivec2(in_parameter); + } + + void main() { + float a; + fun(vec2(1.0), a); + } + "#, + ) + .unwrap(); +} + +#[test] +fn constants() { + use crate::{Constant, Expression, Type, TypeInner}; + + let mut frontend = Frontend::default(); + + let module = frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + const float a = 1.0; + float global = a; + const float b = a; + + void main() {} + "#, + ) + .unwrap(); + + let mut types = module.types.iter(); + let mut constants = module.constants.iter(); + let mut const_expressions = module.const_expressions.iter(); + + let (ty_handle, ty) = types.next().unwrap(); + assert_eq!( + ty, + &Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::F32) + } + ); + + let (init_handle, init) = const_expressions.next().unwrap(); + assert_eq!(init, &Expression::Literal(crate::Literal::F32(1.0))); + + assert_eq!( + constants.next().unwrap().1, + &Constant { + name: Some("a".to_owned()), + r#override: crate::Override::None, + ty: ty_handle, + init: init_handle + } + ); + + assert_eq!( + constants.next().unwrap().1, + &Constant { + name: Some("b".to_owned()), + r#override: crate::Override::None, + ty: ty_handle, + init: init_handle + } + ); + + assert!(constants.next().is_none()); +} + +#[test] +fn function_overloading() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + + float saturate(float v) { return clamp(v, 0.0, 1.0); } + vec2 saturate(vec2 v) { return clamp(v, vec2(0.0), vec2(1.0)); } + vec3 saturate(vec3 v) { return clamp(v, vec3(0.0), vec3(1.0)); } + vec4 saturate(vec4 v) { return clamp(v, vec4(0.0), vec4(1.0)); } + + void main() { + float v1 = saturate(1.5); + vec2 v2 = saturate(vec2(0.5, 1.5)); + vec3 v3 = saturate(vec3(0.5, 1.5, 2.5)); + vec3 v4 = saturate(vec4(0.5, 1.5, 2.5, 3.5)); + } + "#, + ) + .unwrap(); +} + +#[test] +fn implicit_conversions() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + mat4 a = mat4(1); + float b = 1u; + float c = 1 + 2.0; + } + "#, + ) + .unwrap(); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void test(int a) {} + void test(uint a) {} + + void main() { + test(1.0); + } + "#, + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::SemanticError("Unknown function \'test\'".into()), + meta: Span::new(156, 165), + }] + }, + ); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void test(float a) {} + void test(uint a) {} + + void main() { + test(1); + } + "#, + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()), + meta: Span::new(158, 165), + }] + } + ); +} + +#[test] +fn structs() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + Test { + vec4 pos; + } xx; + + void main() {} + "#, + ) + .unwrap_err(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + struct Test { + vec4 pos; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + const int NUM_VECS = 42; + struct Test { + vec4 vecs[NUM_VECS]; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + struct Hello { + vec4 test; + } test() { + return Hello( vec4(1.0) ); + } + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + struct Test {}; + + void main() {} + "#, + ) + .unwrap_err(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + inout struct Test { + vec4 x; + }; + + void main() {} + "#, + ) + .unwrap_err(); +} + +#[test] +fn swizzles() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + vec4 v = vec4(1); + v.xyz = vec3(2); + v.x = 5.0; + v.xyz.zxy.yx.xy = vec2(5.0, 1.0); + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + vec4 v = vec4(1); + v.xx = vec2(5.0); + } + "#, + ) + .unwrap_err(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + vec3 v = vec3(1); + v.w = 2.0; + } + "#, + ) + .unwrap_err(); +} + +#[test] +fn expressions() { + let mut frontend = Frontend::default(); + + // Vector indexing + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float test(int index) { + vec4 v = vec4(1.0, 2.0, 3.0, 4.0); + return v[index] + 1.0; + } + + void main() {} + "#, + ) + .unwrap(); + + // Prefix increment/decrement + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + uint index = 0; + + --index; + ++index; + } + "#, + ) + .unwrap(); + + // Dynamic indexing of array + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + const vec4 positions[1] = { vec4(0) }; + + gl_Position = positions[gl_VertexIndex]; + } + "#, + ) + .unwrap(); +} diff --git a/third_party/rust/naga/src/front/glsl/token.rs b/third_party/rust/naga/src/front/glsl/token.rs new file mode 100644 index 0000000000..303723a27b --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/token.rs @@ -0,0 +1,137 @@ +pub use pp_rs::token::{Float, Integer, Location, Token as PPToken}; + +use super::ast::Precision; +use crate::{Interpolation, Sampling, Span, Type}; + +impl From<Location> for Span { + fn from(loc: Location) -> Self { + Span::new(loc.start, loc.end) + } +} + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub struct Token { + pub value: TokenValue, + pub meta: Span, +} + +/// A token passed from the lexing used in the parsing. +/// +/// This type is exported since it's returned in the +/// [`InvalidToken`](super::ErrorKind::InvalidToken) error. +#[derive(Clone, Debug, PartialEq)] +pub enum TokenValue { + Identifier(String), + + FloatConstant(Float), + IntConstant(Integer), + BoolConstant(bool), + + Layout, + In, + Out, + InOut, + Uniform, + Buffer, + Const, + Shared, + + Restrict, + /// A `glsl` memory qualifier such as `writeonly` + /// + /// The associated [`crate::StorageAccess`] is the access being allowed + /// (for example `writeonly` has an associated value of [`crate::StorageAccess::STORE`]) + MemoryQualifier(crate::StorageAccess), + + Invariant, + Interpolation(Interpolation), + Sampling(Sampling), + Precision, + PrecisionQualifier(Precision), + + Continue, + Break, + Return, + Discard, + + If, + Else, + Switch, + Case, + Default, + While, + Do, + For, + + Void, + Struct, + TypeName(Type), + + Assign, + AddAssign, + SubAssign, + MulAssign, + DivAssign, + ModAssign, + LeftShiftAssign, + RightShiftAssign, + AndAssign, + XorAssign, + OrAssign, + + Increment, + Decrement, + + LogicalOr, + LogicalAnd, + LogicalXor, + + LessEqual, + GreaterEqual, + Equal, + NotEqual, + + LeftShift, + RightShift, + + LeftBrace, + RightBrace, + LeftParen, + RightParen, + LeftBracket, + RightBracket, + LeftAngle, + RightAngle, + + Comma, + Semicolon, + Colon, + Dot, + Bang, + Dash, + Tilde, + Plus, + Star, + Slash, + Percent, + VerticalBar, + Caret, + Ampersand, + Question, +} + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub struct Directive { + pub kind: DirectiveKind, + pub tokens: Vec<PPToken>, +} + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub enum DirectiveKind { + Version { is_first_directive: bool }, + Extension, + Pragma, +} diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs new file mode 100644 index 0000000000..e87d76fffc --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/types.rs @@ -0,0 +1,360 @@ +use super::{context::Context, Error, ErrorKind, Result, Span}; +use crate::{ + proc::ResolveContext, Expression, Handle, ImageClass, ImageDimension, Scalar, ScalarKind, Type, + TypeInner, VectorSize, +}; + +pub fn parse_type(type_name: &str) -> Option<Type> { + match type_name { + "bool" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::BOOL), + }), + "float" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::F32), + }), + "double" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::F64), + }), + "int" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::I32), + }), + "uint" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::U32), + }), + "sampler" | "samplerShadow" => Some(Type { + name: None, + inner: TypeInner::Sampler { + comparison: type_name == "samplerShadow", + }, + }), + word => { + fn kind_width_parse(ty: &str) -> Option<Scalar> { + Some(match ty { + "" => Scalar::F32, + "b" => Scalar::BOOL, + "i" => Scalar::I32, + "u" => Scalar::U32, + "d" => Scalar::F64, + _ => return None, + }) + } + + fn size_parse(n: &str) -> Option<VectorSize> { + Some(match n { + "2" => VectorSize::Bi, + "3" => VectorSize::Tri, + "4" => VectorSize::Quad, + _ => return None, + }) + } + + let vec_parse = |word: &str| { + let mut iter = word.split("vec"); + + let kind = iter.next()?; + let size = iter.next()?; + let scalar = kind_width_parse(kind)?; + let size = size_parse(size)?; + + Some(Type { + name: None, + inner: TypeInner::Vector { size, scalar }, + }) + }; + + let mat_parse = |word: &str| { + let mut iter = word.split("mat"); + + let kind = iter.next()?; + let size = iter.next()?; + let scalar = kind_width_parse(kind)?; + + let (columns, rows) = if let Some(size) = size_parse(size) { + (size, size) + } else { + let mut iter = size.split('x'); + match (iter.next()?, iter.next()?, iter.next()) { + (col, row, None) => (size_parse(col)?, size_parse(row)?), + _ => return None, + } + }; + + Some(Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + scalar, + }, + }) + }; + + let texture_parse = |word: &str| { + let mut iter = word.split("texture"); + + let texture_kind = |ty| { + Some(match ty { + "" => ScalarKind::Float, + "i" => ScalarKind::Sint, + "u" => ScalarKind::Uint, + _ => return None, + }) + }; + + let kind = iter.next()?; + let size = iter.next()?; + let kind = texture_kind(kind)?; + + let sampled = |multi| ImageClass::Sampled { kind, multi }; + + let (dim, arrayed, class) = match size { + "1D" => (ImageDimension::D1, false, sampled(false)), + "1DArray" => (ImageDimension::D1, true, sampled(false)), + "2D" => (ImageDimension::D2, false, sampled(false)), + "2DArray" => (ImageDimension::D2, true, sampled(false)), + "2DMS" => (ImageDimension::D2, false, sampled(true)), + "2DMSArray" => (ImageDimension::D2, true, sampled(true)), + "3D" => (ImageDimension::D3, false, sampled(false)), + "Cube" => (ImageDimension::Cube, false, sampled(false)), + "CubeArray" => (ImageDimension::Cube, true, sampled(false)), + _ => return None, + }; + + Some(Type { + name: None, + inner: TypeInner::Image { + dim, + arrayed, + class, + }, + }) + }; + + let image_parse = |word: &str| { + let mut iter = word.split("image"); + + let texture_kind = |ty| { + Some(match ty { + "" => ScalarKind::Float, + "i" => ScalarKind::Sint, + "u" => ScalarKind::Uint, + _ => return None, + }) + }; + + let kind = iter.next()?; + let size = iter.next()?; + // TODO: Check that the texture format and the kind match + let _ = texture_kind(kind)?; + + let class = ImageClass::Storage { + format: crate::StorageFormat::R8Uint, + access: crate::StorageAccess::all(), + }; + + // TODO: glsl support multisampled storage images, naga doesn't + let (dim, arrayed) = match size { + "1D" => (ImageDimension::D1, false), + "1DArray" => (ImageDimension::D1, true), + "2D" => (ImageDimension::D2, false), + "2DArray" => (ImageDimension::D2, true), + "3D" => (ImageDimension::D3, false), + // Naga doesn't support cube images and it's usefulness + // is questionable, so they won't be supported for now + // "Cube" => (ImageDimension::Cube, false), + // "CubeArray" => (ImageDimension::Cube, true), + _ => return None, + }; + + Some(Type { + name: None, + inner: TypeInner::Image { + dim, + arrayed, + class, + }, + }) + }; + + vec_parse(word) + .or_else(|| mat_parse(word)) + .or_else(|| texture_parse(word)) + .or_else(|| image_parse(word)) + } + } +} + +pub const fn scalar_components(ty: &TypeInner) -> Option<Scalar> { + match *ty { + TypeInner::Scalar(scalar) + | TypeInner::Vector { scalar, .. } + | TypeInner::ValuePointer { scalar, .. } + | TypeInner::Matrix { scalar, .. } => Some(scalar), + _ => None, + } +} + +pub const fn type_power(scalar: Scalar) -> Option<u32> { + Some(match scalar.kind { + ScalarKind::Sint => 0, + ScalarKind::Uint => 1, + ScalarKind::Float if scalar.width == 4 => 2, + ScalarKind::Float => 3, + ScalarKind::Bool | ScalarKind::AbstractInt | ScalarKind::AbstractFloat => return None, + }) +} + +impl Context<'_> { + /// Resolves the types of the expressions until `expr` (inclusive) + /// + /// This needs to be done before the [`typifier`] can be queried for + /// the types of the expressions in the range between the last grow and `expr`. + /// + /// # Note + /// + /// The `resolve_type*` methods (like [`resolve_type`]) automatically + /// grow the [`typifier`] so calling this method is not necessary when using + /// them. + /// + /// [`typifier`]: Context::typifier + /// [`resolve_type`]: Self::resolve_type + pub(crate) fn typifier_grow(&mut self, expr: Handle<Expression>, meta: Span) -> Result<()> { + let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments); + + let typifier = if self.is_const { + &mut self.const_typifier + } else { + &mut self.typifier + }; + + let expressions = if self.is_const { + &self.module.const_expressions + } else { + &self.expressions + }; + + typifier + .grow(expr, expressions, &resolve_ctx) + .map_err(|error| Error { + kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()), + meta, + }) + } + + pub(crate) fn get_type(&self, expr: Handle<Expression>) -> &TypeInner { + let typifier = if self.is_const { + &self.const_typifier + } else { + &self.typifier + }; + + typifier.get(expr, &self.module.types) + } + + /// Gets the type for the result of the `expr` expression + /// + /// Automatically grows the [`typifier`] to `expr` so calling + /// [`typifier_grow`] is not necessary + /// + /// [`typifier`]: Context::typifier + /// [`typifier_grow`]: Self::typifier_grow + pub(crate) fn resolve_type( + &mut self, + expr: Handle<Expression>, + meta: Span, + ) -> Result<&TypeInner> { + self.typifier_grow(expr, meta)?; + Ok(self.get_type(expr)) + } + + /// Gets the type handle for the result of the `expr` expression + /// + /// Automatically grows the [`typifier`] to `expr` so calling + /// [`typifier_grow`] is not necessary + /// + /// # Note + /// + /// Consider using [`resolve_type`] whenever possible + /// since it doesn't require adding each type to the [`types`] arena + /// and it doesn't need to mutably borrow the [`Parser`][Self] + /// + /// [`types`]: crate::Module::types + /// [`typifier`]: Context::typifier + /// [`typifier_grow`]: Self::typifier_grow + /// [`resolve_type`]: Self::resolve_type + pub(crate) fn resolve_type_handle( + &mut self, + expr: Handle<Expression>, + meta: Span, + ) -> Result<Handle<Type>> { + self.typifier_grow(expr, meta)?; + + let typifier = if self.is_const { + &mut self.const_typifier + } else { + &mut self.typifier + }; + + Ok(typifier.register_type(expr, &mut self.module.types)) + } + + /// Invalidates the cached type resolution for `expr` forcing a recomputation + pub(crate) fn invalidate_expression( + &mut self, + expr: Handle<Expression>, + meta: Span, + ) -> Result<()> { + let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments); + + let typifier = if self.is_const { + &mut self.const_typifier + } else { + &mut self.typifier + }; + + typifier + .invalidate(expr, &self.expressions, &resolve_ctx) + .map_err(|error| Error { + kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()), + meta, + }) + } + + pub(crate) fn lift_up_const_expression( + &mut self, + expr: Handle<Expression>, + ) -> Result<Handle<Expression>> { + let meta = self.expressions.get_span(expr); + Ok(match self.expressions[expr] { + ref expr @ (Expression::Literal(_) + | Expression::Constant(_) + | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta), + Expression::Compose { ty, ref components } => { + let mut components = components.clone(); + for component in &mut components { + *component = self.lift_up_const_expression(*component)?; + } + self.module + .const_expressions + .append(Expression::Compose { ty, components }, meta) + } + Expression::Splat { size, value } => { + let value = self.lift_up_const_expression(value)?; + self.module + .const_expressions + .append(Expression::Splat { size, value }, meta) + } + _ => { + return Err(Error { + kind: ErrorKind::SemanticError("Expression is not const-expression".into()), + meta, + }) + } + }) + } +} diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs new file mode 100644 index 0000000000..5af2b228f0 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/variables.rs @@ -0,0 +1,646 @@ +use super::{ + ast::*, + context::{Context, ExprPos}, + error::{Error, ErrorKind}, + Frontend, Result, Span, +}; +use crate::{ + AddressSpace, Binding, BuiltIn, Constant, Expression, GlobalVariable, Handle, Interpolation, + LocalVariable, ResourceBinding, Scalar, ScalarKind, ShaderStage, SwizzleComponent, Type, + TypeInner, VectorSize, +}; + +pub struct VarDeclaration<'a, 'key> { + pub qualifiers: &'a mut TypeQualifiers<'key>, + pub ty: Handle<Type>, + pub name: Option<String>, + pub init: Option<Handle<Expression>>, + pub meta: Span, +} + +/// Information about a builtin used in [`add_builtin`](Frontend::add_builtin). +struct BuiltInData { + /// The type of the builtin. + inner: TypeInner, + /// The associated builtin class. + builtin: BuiltIn, + /// Whether the builtin can be written to or not. + mutable: bool, + /// The storage used for the builtin. + storage: StorageQualifier, +} + +pub enum GlobalOrConstant { + Global(Handle<GlobalVariable>), + Constant(Handle<Constant>), +} + +impl Frontend { + /// Adds a builtin and returns a variable reference to it + fn add_builtin( + &mut self, + ctx: &mut Context, + name: &str, + data: BuiltInData, + meta: Span, + ) -> Result<Option<VariableReference>> { + let ty = ctx.module.types.insert( + Type { + name: None, + inner: data.inner, + }, + meta, + ); + + let handle = ctx.module.global_variables.append( + GlobalVariable { + name: Some(name.into()), + space: AddressSpace::Private, + binding: None, + ty, + init: None, + }, + meta, + ); + + let idx = self.entry_args.len(); + self.entry_args.push(EntryArg { + name: None, + binding: Binding::BuiltIn(data.builtin), + handle, + storage: data.storage, + }); + + self.global_variables.push(( + name.into(), + GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: Some(idx), + mutable: data.mutable, + }, + )); + + let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta)?; + + let var = VariableReference { + expr, + load: true, + mutable: data.mutable, + constant: None, + entry_arg: Some(idx), + }; + + ctx.symbol_table.add_root(name.into(), var.clone()); + + Ok(Some(var)) + } + + pub(crate) fn lookup_variable( + &mut self, + ctx: &mut Context, + name: &str, + meta: Span, + ) -> Result<Option<VariableReference>> { + if let Some(var) = ctx.symbol_table.lookup(name).cloned() { + return Ok(Some(var)); + } + + let data = match name { + "gl_Position" => BuiltInData { + inner: TypeInner::Vector { + size: VectorSize::Quad, + scalar: Scalar::F32, + }, + builtin: BuiltIn::Position { invariant: false }, + mutable: true, + storage: StorageQualifier::Output, + }, + "gl_FragCoord" => BuiltInData { + inner: TypeInner::Vector { + size: VectorSize::Quad, + scalar: Scalar::F32, + }, + builtin: BuiltIn::Position { invariant: false }, + mutable: false, + storage: StorageQualifier::Input, + }, + "gl_PointCoord" => BuiltInData { + inner: TypeInner::Vector { + size: VectorSize::Bi, + scalar: Scalar::F32, + }, + builtin: BuiltIn::PointCoord, + mutable: false, + storage: StorageQualifier::Input, + }, + "gl_GlobalInvocationID" + | "gl_NumWorkGroups" + | "gl_WorkGroupSize" + | "gl_WorkGroupID" + | "gl_LocalInvocationID" => BuiltInData { + inner: TypeInner::Vector { + size: VectorSize::Tri, + scalar: Scalar::U32, + }, + builtin: match name { + "gl_GlobalInvocationID" => BuiltIn::GlobalInvocationId, + "gl_NumWorkGroups" => BuiltIn::NumWorkGroups, + "gl_WorkGroupSize" => BuiltIn::WorkGroupSize, + "gl_WorkGroupID" => BuiltIn::WorkGroupId, + "gl_LocalInvocationID" => BuiltIn::LocalInvocationId, + _ => unreachable!(), + }, + mutable: false, + storage: StorageQualifier::Input, + }, + "gl_FrontFacing" => BuiltInData { + inner: TypeInner::Scalar(Scalar::BOOL), + builtin: BuiltIn::FrontFacing, + mutable: false, + storage: StorageQualifier::Input, + }, + "gl_PointSize" | "gl_FragDepth" => BuiltInData { + inner: TypeInner::Scalar(Scalar::F32), + builtin: match name { + "gl_PointSize" => BuiltIn::PointSize, + "gl_FragDepth" => BuiltIn::FragDepth, + _ => unreachable!(), + }, + mutable: true, + storage: StorageQualifier::Output, + }, + "gl_ClipDistance" | "gl_CullDistance" => { + let base = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Scalar(Scalar::F32), + }, + meta, + ); + + BuiltInData { + inner: TypeInner::Array { + base, + size: crate::ArraySize::Dynamic, + stride: 4, + }, + builtin: match name { + "gl_ClipDistance" => BuiltIn::ClipDistance, + "gl_CullDistance" => BuiltIn::CullDistance, + _ => unreachable!(), + }, + mutable: self.meta.stage == ShaderStage::Vertex, + storage: StorageQualifier::Output, + } + } + _ => { + let builtin = match name { + "gl_BaseVertex" => BuiltIn::BaseVertex, + "gl_BaseInstance" => BuiltIn::BaseInstance, + "gl_PrimitiveID" => BuiltIn::PrimitiveIndex, + "gl_InstanceIndex" => BuiltIn::InstanceIndex, + "gl_VertexIndex" => BuiltIn::VertexIndex, + "gl_SampleID" => BuiltIn::SampleIndex, + "gl_LocalInvocationIndex" => BuiltIn::LocalInvocationIndex, + _ => return Ok(None), + }; + + BuiltInData { + inner: TypeInner::Scalar(Scalar::U32), + builtin, + mutable: false, + storage: StorageQualifier::Input, + } + } + }; + + self.add_builtin(ctx, name, data, meta) + } + + pub(crate) fn make_variable_invariant( + &mut self, + ctx: &mut Context, + name: &str, + meta: Span, + ) -> Result<()> { + if let Some(var) = self.lookup_variable(ctx, name, meta)? { + if let Some(index) = var.entry_arg { + if let Binding::BuiltIn(BuiltIn::Position { ref mut invariant }) = + self.entry_args[index].binding + { + *invariant = true; + } + } + } + Ok(()) + } + + pub(crate) fn field_selection( + &mut self, + ctx: &mut Context, + pos: ExprPos, + expression: Handle<Expression>, + name: &str, + meta: Span, + ) -> Result<Handle<Expression>> { + let (ty, is_pointer) = match *ctx.resolve_type(expression, meta)? { + TypeInner::Pointer { base, .. } => (&ctx.module.types[base].inner, true), + ref ty => (ty, false), + }; + match *ty { + TypeInner::Struct { ref members, .. } => { + let index = members + .iter() + .position(|m| m.name == Some(name.into())) + .ok_or_else(|| Error { + kind: ErrorKind::UnknownField(name.into()), + meta, + })?; + let pointer = ctx.add_expression( + Expression::AccessIndex { + base: expression, + index: index as u32, + }, + meta, + )?; + + Ok(match pos { + ExprPos::Rhs if is_pointer => { + ctx.add_expression(Expression::Load { pointer }, meta)? + } + _ => pointer, + }) + } + // swizzles (xyzw, rgba, stpq) + TypeInner::Vector { size, .. } => { + let check_swizzle_components = |comps: &str| { + name.chars() + .map(|c| { + comps + .find(c) + .filter(|i| *i < size as usize) + .map(|i| SwizzleComponent::from_index(i as u32)) + }) + .collect::<Option<Vec<SwizzleComponent>>>() + }; + + let components = check_swizzle_components("xyzw") + .or_else(|| check_swizzle_components("rgba")) + .or_else(|| check_swizzle_components("stpq")); + + if let Some(components) = components { + if let ExprPos::Lhs = pos { + let not_unique = (1..components.len()) + .any(|i| components[i..].contains(&components[i - 1])); + if not_unique { + self.errors.push(Error { + kind: + ErrorKind::SemanticError( + format!( + "swizzle cannot have duplicate components in left-hand-side expression for \"{name:?}\"" + ) + .into(), + ), + meta , + }) + } + } + + let mut pattern = [SwizzleComponent::X; 4]; + for (pat, component) in pattern.iter_mut().zip(&components) { + *pat = *component; + } + + // flatten nested swizzles (vec.zyx.xy.x => vec.z) + let mut expression = expression; + if let Expression::Swizzle { + size: _, + vector, + pattern: ref src_pattern, + } = ctx[expression] + { + expression = vector; + for pat in &mut pattern { + *pat = src_pattern[pat.index() as usize]; + } + } + + let size = match components.len() { + // Swizzles with just one component are accesses and not swizzles + 1 => { + match pos { + // If the position is in the right hand side and the base + // vector is a pointer, load it, otherwise the swizzle would + // produce a pointer + ExprPos::Rhs if is_pointer => { + expression = ctx.add_expression( + Expression::Load { + pointer: expression, + }, + meta, + )?; + } + _ => {} + }; + return ctx.add_expression( + Expression::AccessIndex { + base: expression, + index: pattern[0].index(), + }, + meta, + ); + } + 2 => VectorSize::Bi, + 3 => VectorSize::Tri, + 4 => VectorSize::Quad, + _ => { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + format!("Bad swizzle size for \"{name:?}\"").into(), + ), + meta, + }); + + VectorSize::Quad + } + }; + + if is_pointer { + // NOTE: for lhs expression, this extra load ends up as an unused expr, because the + // assignment will extract the pointer and use it directly anyway. Unfortunately we + // need it for validation to pass, as swizzles cannot operate on pointer values. + expression = ctx.add_expression( + Expression::Load { + pointer: expression, + }, + meta, + )?; + } + + Ok(ctx.add_expression( + Expression::Swizzle { + size, + vector: expression, + pattern, + }, + meta, + )?) + } else { + Err(Error { + kind: ErrorKind::SemanticError( + format!("Invalid swizzle for vector \"{name}\"").into(), + ), + meta, + }) + } + } + _ => Err(Error { + kind: ErrorKind::SemanticError( + format!("Can't lookup field on this type \"{name}\"").into(), + ), + meta, + }), + } + } + + pub(crate) fn add_global_var( + &mut self, + ctx: &mut Context, + VarDeclaration { + qualifiers, + mut ty, + name, + init, + meta, + }: VarDeclaration, + ) -> Result<GlobalOrConstant> { + let storage = qualifiers.storage.0; + let (ret, lookup) = match storage { + StorageQualifier::Input | StorageQualifier::Output => { + let input = storage == StorageQualifier::Input; + // TODO: glslang seems to use a counter for variables without + // explicit location (even if that causes collisions) + let location = qualifiers + .uint_layout_qualifier("location", &mut self.errors) + .unwrap_or(0); + let interpolation = qualifiers.interpolation.take().map(|(i, _)| i).or_else(|| { + let kind = ctx.module.types[ty].inner.scalar_kind()?; + Some(match kind { + ScalarKind::Float => Interpolation::Perspective, + _ => Interpolation::Flat, + }) + }); + let sampling = qualifiers.sampling.take().map(|(s, _)| s); + + let handle = ctx.module.global_variables.append( + GlobalVariable { + name: name.clone(), + space: AddressSpace::Private, + binding: None, + ty, + init, + }, + meta, + ); + + let idx = self.entry_args.len(); + self.entry_args.push(EntryArg { + name: name.clone(), + binding: Binding::Location { + location, + interpolation, + sampling, + second_blend_source: false, + }, + handle, + storage, + }); + + let lookup = GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: Some(idx), + mutable: !input, + }; + + (GlobalOrConstant::Global(handle), lookup) + } + StorageQualifier::Const => { + let init = init.ok_or_else(|| Error { + kind: ErrorKind::SemanticError("const values must have an initializer".into()), + meta, + })?; + + let constant = Constant { + name: name.clone(), + r#override: crate::Override::None, + ty, + init, + }; + let handle = ctx.module.constants.fetch_or_append(constant, meta); + + let lookup = GlobalLookup { + kind: GlobalLookupKind::Constant(handle, ty), + entry_arg: None, + mutable: false, + }; + + (GlobalOrConstant::Constant(handle), lookup) + } + StorageQualifier::AddressSpace(mut space) => { + match space { + AddressSpace::Storage { ref mut access } => { + if let Some((allowed_access, _)) = qualifiers.storage_access.take() { + *access = allowed_access; + } + } + AddressSpace::Uniform => match ctx.module.types[ty].inner { + TypeInner::Image { + class, + dim, + arrayed, + } => { + if let crate::ImageClass::Storage { + mut access, + mut format, + } = class + { + if let Some((allowed_access, _)) = qualifiers.storage_access.take() + { + access = allowed_access; + } + + match qualifiers.layout_qualifiers.remove(&QualifierKey::Format) { + Some((QualifierValue::Format(f), _)) => format = f, + // TODO: glsl supports images without format qualifier + // if they are `writeonly` + None => self.errors.push(Error { + kind: ErrorKind::SemanticError( + "image types require a format layout qualifier".into(), + ), + meta, + }), + _ => unreachable!(), + } + + ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Image { + dim, + arrayed, + class: crate::ImageClass::Storage { format, access }, + }, + }, + meta, + ); + } + + space = AddressSpace::Handle + } + TypeInner::Sampler { .. } => space = AddressSpace::Handle, + _ => { + if qualifiers.none_layout_qualifier("push_constant", &mut self.errors) { + space = AddressSpace::PushConstant + } + } + }, + AddressSpace::Function => space = AddressSpace::Private, + _ => {} + }; + + let binding = match space { + AddressSpace::Uniform | AddressSpace::Storage { .. } | AddressSpace::Handle => { + let binding = qualifiers.uint_layout_qualifier("binding", &mut self.errors); + if binding.is_none() { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + "uniform/buffer blocks require layout(binding=X)".into(), + ), + meta, + }); + } + let set = qualifiers.uint_layout_qualifier("set", &mut self.errors); + binding.map(|binding| ResourceBinding { + group: set.unwrap_or(0), + binding, + }) + } + _ => None, + }; + + let handle = ctx.module.global_variables.append( + GlobalVariable { + name: name.clone(), + space, + binding, + ty, + init, + }, + meta, + ); + + let lookup = GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: None, + mutable: true, + }; + + (GlobalOrConstant::Global(handle), lookup) + } + }; + + if let Some(name) = name { + ctx.add_global(&name, lookup)?; + + self.global_variables.push((name, lookup)); + } + + qualifiers.unused_errors(&mut self.errors); + + Ok(ret) + } + + pub(crate) fn add_local_var( + &mut self, + ctx: &mut Context, + decl: VarDeclaration, + ) -> Result<Handle<Expression>> { + let storage = decl.qualifiers.storage; + let mutable = match storage.0 { + StorageQualifier::AddressSpace(AddressSpace::Function) => true, + StorageQualifier::Const => false, + _ => { + self.errors.push(Error { + kind: ErrorKind::SemanticError("Locals cannot have a storage qualifier".into()), + meta: storage.1, + }); + true + } + }; + + let handle = ctx.locals.append( + LocalVariable { + name: decl.name.clone(), + ty: decl.ty, + init: decl.init, + }, + decl.meta, + ); + let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta)?; + + if let Some(name) = decl.name { + let maybe_var = ctx.add_local_var(name.clone(), expr, mutable); + + if maybe_var.is_some() { + self.errors.push(Error { + kind: ErrorKind::VariableAlreadyDeclared(name), + meta: decl.meta, + }) + } + } + + decl.qualifiers.unused_errors(&mut self.errors); + + Ok(expr) + } +} diff --git a/third_party/rust/naga/src/front/interpolator.rs b/third_party/rust/naga/src/front/interpolator.rs new file mode 100644 index 0000000000..0196a2254d --- /dev/null +++ b/third_party/rust/naga/src/front/interpolator.rs @@ -0,0 +1,62 @@ +/*! +Interpolation defaults. +*/ + +impl crate::Binding { + /// Apply the usual default interpolation for `ty` to `binding`. + /// + /// This function is a utility front ends may use to satisfy the Naga IR's + /// requirement, meant to ensure that input languages' policies have been + /// applied appropriately, that all I/O `Binding`s from the vertex shader to the + /// fragment shader must have non-`None` `interpolation` values. + /// + /// All the shader languages Naga supports have similar rules: + /// perspective-correct, center-sampled interpolation is the default for any + /// binding that can vary, and everything else either defaults to flat, or + /// requires an explicit flat qualifier/attribute/what-have-you. + /// + /// If `binding` is not a [`Location`] binding, or if its [`interpolation`] is + /// already set, then make no changes. Otherwise, set `binding`'s interpolation + /// and sampling to reasonable defaults depending on `ty`, the type of the value + /// being interpolated: + /// + /// - If `ty` is a floating-point scalar, vector, or matrix type, then + /// default to [`Perspective`] interpolation and [`Center`] sampling. + /// + /// - If `ty` is an integral scalar or vector, then default to [`Flat`] + /// interpolation, which has no associated sampling. + /// + /// - For any other types, make no change. Such types are not permitted as + /// user-defined IO values, and will probably be flagged by the verifier + /// + /// When structs appear in input or output types, each member ought to have its + /// own [`Binding`], so structs are simply covered by the third case. + /// + /// [`Binding`]: crate::Binding + /// [`Location`]: crate::Binding::Location + /// [`interpolation`]: crate::Binding::Location::interpolation + /// [`Perspective`]: crate::Interpolation::Perspective + /// [`Flat`]: crate::Interpolation::Flat + /// [`Center`]: crate::Sampling::Center + pub fn apply_default_interpolation(&mut self, ty: &crate::TypeInner) { + if let crate::Binding::Location { + location: _, + interpolation: ref mut interpolation @ None, + ref mut sampling, + second_blend_source: _, + } = *self + { + match ty.scalar_kind() { + Some(crate::ScalarKind::Float) => { + *interpolation = Some(crate::Interpolation::Perspective); + *sampling = Some(crate::Sampling::Center); + } + Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { + *interpolation = Some(crate::Interpolation::Flat); + *sampling = None; + } + Some(_) | None => {} + } + } + } +} diff --git a/third_party/rust/naga/src/front/mod.rs b/third_party/rust/naga/src/front/mod.rs new file mode 100644 index 0000000000..e1f99452e1 --- /dev/null +++ b/third_party/rust/naga/src/front/mod.rs @@ -0,0 +1,328 @@ +/*! +Frontend parsers that consume binary and text shaders and load them into [`Module`](super::Module)s. +*/ + +mod interpolator; +mod type_gen; + +#[cfg(feature = "glsl-in")] +pub mod glsl; +#[cfg(feature = "spv-in")] +pub mod spv; +#[cfg(feature = "wgsl-in")] +pub mod wgsl; + +use crate::{ + arena::{Arena, Handle, UniqueArena}, + proc::{ResolveContext, ResolveError, TypeResolution}, + FastHashMap, +}; +use std::ops; + +/// A table of types for an `Arena<Expression>`. +/// +/// A front end can use a `Typifier` to get types for an arena's expressions +/// while it is still contributing expressions to it. At any point, you can call +/// [`typifier.grow(expr, arena, ctx)`], where `expr` is a `Handle<Expression>` +/// referring to something in `arena`, and the `Typifier` will resolve the types +/// of all the expressions up to and including `expr`. Then you can write +/// `typifier[handle]` to get the type of any handle at or before `expr`. +/// +/// Note that `Typifier` does *not* build an `Arena<Type>` as a part of its +/// usual operation. Ideally, a module's type arena should only contain types +/// actually needed by `Handle<Type>`s elsewhere in the module — functions, +/// variables, [`Compose`] expressions, other types, and so on — so we don't +/// want every little thing that occurs as the type of some intermediate +/// expression to show up there. +/// +/// Instead, `Typifier` accumulates a [`TypeResolution`] for each expression, +/// which refers to the `Arena<Type>` in the [`ResolveContext`] passed to `grow` +/// as needed. [`TypeResolution`] is a lightweight representation for +/// intermediate types like this; see its documentation for details. +/// +/// If you do need to register a `Typifier`'s conclusion in an `Arena<Type>` +/// (say, for a [`LocalVariable`] whose type you've inferred), you can use +/// [`register_type`] to do so. +/// +/// [`typifier.grow(expr, arena)`]: Typifier::grow +/// [`register_type`]: Typifier::register_type +/// [`Compose`]: crate::Expression::Compose +/// [`LocalVariable`]: crate::LocalVariable +#[derive(Debug, Default)] +pub struct Typifier { + resolutions: Vec<TypeResolution>, +} + +impl Typifier { + pub const fn new() -> Self { + Typifier { + resolutions: Vec::new(), + } + } + + pub fn reset(&mut self) { + self.resolutions.clear() + } + + pub fn get<'a>( + &'a self, + expr_handle: Handle<crate::Expression>, + types: &'a UniqueArena<crate::Type>, + ) -> &'a crate::TypeInner { + self.resolutions[expr_handle.index()].inner_with(types) + } + + /// Add an expression's type to an `Arena<Type>`. + /// + /// Add the type of `expr_handle` to `types`, and return a `Handle<Type>` + /// referring to it. + /// + /// # Note + /// + /// If you just need a [`TypeInner`] for `expr_handle`'s type, consider + /// using `typifier[expression].inner_with(types)` instead. Calling + /// [`TypeResolution::inner_with`] often lets us avoid adding anything to + /// the arena, which can significantly reduce the number of types that end + /// up in the final module. + /// + /// [`TypeInner`]: crate::TypeInner + pub fn register_type( + &self, + expr_handle: Handle<crate::Expression>, + types: &mut UniqueArena<crate::Type>, + ) -> Handle<crate::Type> { + match self[expr_handle].clone() { + TypeResolution::Handle(handle) => handle, + TypeResolution::Value(inner) => { + types.insert(crate::Type { name: None, inner }, crate::Span::UNDEFINED) + } + } + } + + /// Grow this typifier until it contains a type for `expr_handle`. + pub fn grow( + &mut self, + expr_handle: Handle<crate::Expression>, + expressions: &Arena<crate::Expression>, + ctx: &ResolveContext, + ) -> Result<(), ResolveError> { + if self.resolutions.len() <= expr_handle.index() { + for (eh, expr) in expressions.iter().skip(self.resolutions.len()) { + //Note: the closure can't `Err` by construction + let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?; + log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution); + self.resolutions.push(resolution); + } + } + Ok(()) + } + + /// Recompute the type resolution for `expr_handle`. + /// + /// If the type of `expr_handle` hasn't yet been calculated, call + /// [`grow`](Self::grow) to ensure it is covered. + /// + /// In either case, when this returns, `self[expr_handle]` should be an + /// updated type resolution for `expr_handle`. + pub fn invalidate( + &mut self, + expr_handle: Handle<crate::Expression>, + expressions: &Arena<crate::Expression>, + ctx: &ResolveContext, + ) -> Result<(), ResolveError> { + if self.resolutions.len() <= expr_handle.index() { + self.grow(expr_handle, expressions, ctx) + } else { + let expr = &expressions[expr_handle]; + //Note: the closure can't `Err` by construction + let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?; + self.resolutions[expr_handle.index()] = resolution; + Ok(()) + } + } +} + +impl ops::Index<Handle<crate::Expression>> for Typifier { + type Output = TypeResolution; + fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output { + &self.resolutions[handle.index()] + } +} + +/// Type representing a lexical scope, associating a name to a single variable +/// +/// The scope is generic over the variable representation and name representation +/// in order to allow larger flexibility on the frontends on how they might +/// represent them. +type Scope<Name, Var> = FastHashMap<Name, Var>; + +/// Structure responsible for managing variable lookups and keeping track of +/// lexical scopes +/// +/// The symbol table is generic over the variable representation and its name +/// to allow larger flexibility on the frontends on how they might represent them. +/// +/// ``` +/// use naga::front::SymbolTable; +/// +/// // Create a new symbol table with `u32`s representing the variable +/// let mut symbol_table: SymbolTable<&str, u32> = SymbolTable::default(); +/// +/// // Add two variables named `var1` and `var2` with 0 and 2 respectively +/// symbol_table.add("var1", 0); +/// symbol_table.add("var2", 2); +/// +/// // Check that `var1` exists and is `0` +/// assert_eq!(symbol_table.lookup("var1"), Some(&0)); +/// +/// // Push a new scope and add a variable to it named `var1` shadowing the +/// // variable of our previous scope +/// symbol_table.push_scope(); +/// symbol_table.add("var1", 1); +/// +/// // Check that `var1` now points to the new value of `1` and `var2` still +/// // exists with its value of `2` +/// assert_eq!(symbol_table.lookup("var1"), Some(&1)); +/// assert_eq!(symbol_table.lookup("var2"), Some(&2)); +/// +/// // Pop the scope +/// symbol_table.pop_scope(); +/// +/// // Check that `var1` now refers to our initial variable with value `0` +/// assert_eq!(symbol_table.lookup("var1"), Some(&0)); +/// ``` +/// +/// Scopes are ordered as a LIFO stack so a variable defined in a later scope +/// with the same name as another variable defined in a earlier scope will take +/// precedence in the lookup. Scopes can be added with [`push_scope`] and +/// removed with [`pop_scope`]. +/// +/// A root scope is added when the symbol table is created and must always be +/// present. Trying to pop it will result in a panic. +/// +/// Variables can be added with [`add`] and looked up with [`lookup`]. Adding a +/// variable will do so in the currently active scope and as mentioned +/// previously a lookup will search from the current scope to the root scope. +/// +/// [`push_scope`]: Self::push_scope +/// [`pop_scope`]: Self::push_scope +/// [`add`]: Self::add +/// [`lookup`]: Self::lookup +pub struct SymbolTable<Name, Var> { + /// Stack of lexical scopes. Not all scopes are active; see [`cursor`]. + /// + /// [`cursor`]: Self::cursor + scopes: Vec<Scope<Name, Var>>, + /// Limit of the [`scopes`] stack (exclusive). By using a separate value for + /// the stack length instead of `Vec`'s own internal length, the scopes can + /// be reused to cache memory allocations. + /// + /// [`scopes`]: Self::scopes + cursor: usize, +} + +impl<Name, Var> SymbolTable<Name, Var> { + /// Adds a new lexical scope. + /// + /// All variables declared after this point will be added to this scope + /// until another scope is pushed or [`pop_scope`] is called, causing this + /// scope to be removed along with all variables added to it. + /// + /// [`pop_scope`]: Self::pop_scope + pub fn push_scope(&mut self) { + // If the cursor is equal to the scope's stack length then we need to + // push another empty scope. Otherwise we can reuse the already existing + // scope. + if self.scopes.len() == self.cursor { + self.scopes.push(FastHashMap::default()) + } else { + self.scopes[self.cursor].clear(); + } + + self.cursor += 1; + } + + /// Removes the current lexical scope and all its variables + /// + /// # PANICS + /// - If the current lexical scope is the root scope + pub fn pop_scope(&mut self) { + // Despite the method title, the variables are only deleted when the + // scope is reused. This is because while a clear is inevitable if the + // scope needs to be reused, there are cases where the scope might be + // popped and not reused, i.e. if another scope with the same nesting + // level is never pushed again. + assert!(self.cursor != 1, "Tried to pop the root scope"); + + self.cursor -= 1; + } +} + +impl<Name, Var> SymbolTable<Name, Var> +where + Name: std::hash::Hash + Eq, +{ + /// Perform a lookup for a variable named `name`. + /// + /// As stated in the struct level documentation the lookup will proceed from + /// the current scope to the root scope, returning `Some` when a variable is + /// found or `None` if there doesn't exist a variable with `name` in any + /// scope. + pub fn lookup<Q: ?Sized>(&self, name: &Q) -> Option<&Var> + where + Name: std::borrow::Borrow<Q>, + Q: std::hash::Hash + Eq, + { + // Iterate backwards trough the scopes and try to find the variable + for scope in self.scopes[..self.cursor].iter().rev() { + if let Some(var) = scope.get(name) { + return Some(var); + } + } + + None + } + + /// Adds a new variable to the current scope. + /// + /// Returns the previous variable with the same name in this scope if it + /// exists, so that the frontend might handle it in case variable shadowing + /// is disallowed. + pub fn add(&mut self, name: Name, var: Var) -> Option<Var> { + self.scopes[self.cursor - 1].insert(name, var) + } + + /// Adds a new variable to the root scope. + /// + /// This is used in GLSL for builtins which aren't known in advance and only + /// when used for the first time, so there must be a way to add those + /// declarations to the root unconditionally from the current scope. + /// + /// Returns the previous variable with the same name in the root scope if it + /// exists, so that the frontend might handle it in case variable shadowing + /// is disallowed. + pub fn add_root(&mut self, name: Name, var: Var) -> Option<Var> { + self.scopes[0].insert(name, var) + } +} + +impl<Name, Var> Default for SymbolTable<Name, Var> { + /// Constructs a new symbol table with a root scope + fn default() -> Self { + Self { + scopes: vec![FastHashMap::default()], + cursor: 1, + } + } +} + +use std::fmt; + +impl<Name: fmt::Debug, Var: fmt::Debug> fmt::Debug for SymbolTable<Name, Var> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("SymbolTable ")?; + f.debug_list() + .entries(self.scopes[..self.cursor].iter()) + .finish() + } +} diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs new file mode 100644 index 0000000000..f0a714fbeb --- /dev/null +++ b/third_party/rust/naga/src/front/spv/convert.rs @@ -0,0 +1,179 @@ +use super::error::Error; +use std::convert::TryInto; + +pub(super) const fn map_binary_operator(word: spirv::Op) -> Result<crate::BinaryOperator, Error> { + use crate::BinaryOperator; + use spirv::Op; + + match word { + // Arithmetic Instructions +, -, *, /, % + Op::IAdd | Op::FAdd => Ok(BinaryOperator::Add), + Op::ISub | Op::FSub => Ok(BinaryOperator::Subtract), + Op::IMul | Op::FMul => Ok(BinaryOperator::Multiply), + Op::UDiv | Op::SDiv | Op::FDiv => Ok(BinaryOperator::Divide), + Op::SRem => Ok(BinaryOperator::Modulo), + // Relational and Logical Instructions + Op::IEqual | Op::FOrdEqual | Op::FUnordEqual | Op::LogicalEqual => { + Ok(BinaryOperator::Equal) + } + Op::INotEqual | Op::FOrdNotEqual | Op::FUnordNotEqual | Op::LogicalNotEqual => { + Ok(BinaryOperator::NotEqual) + } + Op::ULessThan | Op::SLessThan | Op::FOrdLessThan | Op::FUnordLessThan => { + Ok(BinaryOperator::Less) + } + Op::ULessThanEqual + | Op::SLessThanEqual + | Op::FOrdLessThanEqual + | Op::FUnordLessThanEqual => Ok(BinaryOperator::LessEqual), + Op::UGreaterThan | Op::SGreaterThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan => { + Ok(BinaryOperator::Greater) + } + Op::UGreaterThanEqual + | Op::SGreaterThanEqual + | Op::FOrdGreaterThanEqual + | Op::FUnordGreaterThanEqual => Ok(BinaryOperator::GreaterEqual), + Op::BitwiseOr => Ok(BinaryOperator::InclusiveOr), + Op::BitwiseXor => Ok(BinaryOperator::ExclusiveOr), + Op::BitwiseAnd => Ok(BinaryOperator::And), + _ => Err(Error::UnknownBinaryOperator(word)), + } +} + +pub(super) const fn map_relational_fun( + word: spirv::Op, +) -> Result<crate::RelationalFunction, Error> { + use crate::RelationalFunction as Rf; + use spirv::Op; + + match word { + Op::All => Ok(Rf::All), + Op::Any => Ok(Rf::Any), + Op::IsNan => Ok(Rf::IsNan), + Op::IsInf => Ok(Rf::IsInf), + _ => Err(Error::UnknownRelationalFunction(word)), + } +} + +pub(super) const fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> { + match word { + 2 => Ok(crate::VectorSize::Bi), + 3 => Ok(crate::VectorSize::Tri), + 4 => Ok(crate::VectorSize::Quad), + _ => Err(Error::InvalidVectorSize(word)), + } +} + +pub(super) fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> { + use spirv::Dim as D; + match D::from_u32(word) { + Some(D::Dim1D) => Ok(crate::ImageDimension::D1), + Some(D::Dim2D) => Ok(crate::ImageDimension::D2), + Some(D::Dim3D) => Ok(crate::ImageDimension::D3), + Some(D::DimCube) => Ok(crate::ImageDimension::Cube), + _ => Err(Error::UnsupportedImageDim(word)), + } +} + +pub(super) fn map_image_format(word: spirv::Word) -> Result<crate::StorageFormat, Error> { + match spirv::ImageFormat::from_u32(word) { + Some(spirv::ImageFormat::R8) => Ok(crate::StorageFormat::R8Unorm), + Some(spirv::ImageFormat::R8Snorm) => Ok(crate::StorageFormat::R8Snorm), + Some(spirv::ImageFormat::R8ui) => Ok(crate::StorageFormat::R8Uint), + Some(spirv::ImageFormat::R8i) => Ok(crate::StorageFormat::R8Sint), + Some(spirv::ImageFormat::R16) => Ok(crate::StorageFormat::R16Unorm), + Some(spirv::ImageFormat::R16Snorm) => Ok(crate::StorageFormat::R16Snorm), + Some(spirv::ImageFormat::R16ui) => Ok(crate::StorageFormat::R16Uint), + Some(spirv::ImageFormat::R16i) => Ok(crate::StorageFormat::R16Sint), + Some(spirv::ImageFormat::R16f) => Ok(crate::StorageFormat::R16Float), + Some(spirv::ImageFormat::Rg8) => Ok(crate::StorageFormat::Rg8Unorm), + Some(spirv::ImageFormat::Rg8Snorm) => Ok(crate::StorageFormat::Rg8Snorm), + Some(spirv::ImageFormat::Rg8ui) => Ok(crate::StorageFormat::Rg8Uint), + Some(spirv::ImageFormat::Rg8i) => Ok(crate::StorageFormat::Rg8Sint), + Some(spirv::ImageFormat::R32ui) => Ok(crate::StorageFormat::R32Uint), + Some(spirv::ImageFormat::R32i) => Ok(crate::StorageFormat::R32Sint), + Some(spirv::ImageFormat::R32f) => Ok(crate::StorageFormat::R32Float), + Some(spirv::ImageFormat::Rg16) => Ok(crate::StorageFormat::Rg16Unorm), + Some(spirv::ImageFormat::Rg16Snorm) => Ok(crate::StorageFormat::Rg16Snorm), + Some(spirv::ImageFormat::Rg16ui) => Ok(crate::StorageFormat::Rg16Uint), + Some(spirv::ImageFormat::Rg16i) => Ok(crate::StorageFormat::Rg16Sint), + Some(spirv::ImageFormat::Rg16f) => Ok(crate::StorageFormat::Rg16Float), + Some(spirv::ImageFormat::Rgba8) => Ok(crate::StorageFormat::Rgba8Unorm), + Some(spirv::ImageFormat::Rgba8Snorm) => Ok(crate::StorageFormat::Rgba8Snorm), + Some(spirv::ImageFormat::Rgba8ui) => Ok(crate::StorageFormat::Rgba8Uint), + Some(spirv::ImageFormat::Rgba8i) => Ok(crate::StorageFormat::Rgba8Sint), + Some(spirv::ImageFormat::Rgb10a2ui) => Ok(crate::StorageFormat::Rgb10a2Uint), + Some(spirv::ImageFormat::Rgb10A2) => Ok(crate::StorageFormat::Rgb10a2Unorm), + Some(spirv::ImageFormat::R11fG11fB10f) => Ok(crate::StorageFormat::Rg11b10Float), + Some(spirv::ImageFormat::Rg32ui) => Ok(crate::StorageFormat::Rg32Uint), + Some(spirv::ImageFormat::Rg32i) => Ok(crate::StorageFormat::Rg32Sint), + Some(spirv::ImageFormat::Rg32f) => Ok(crate::StorageFormat::Rg32Float), + Some(spirv::ImageFormat::Rgba16) => Ok(crate::StorageFormat::Rgba16Unorm), + Some(spirv::ImageFormat::Rgba16Snorm) => Ok(crate::StorageFormat::Rgba16Snorm), + Some(spirv::ImageFormat::Rgba16ui) => Ok(crate::StorageFormat::Rgba16Uint), + Some(spirv::ImageFormat::Rgba16i) => Ok(crate::StorageFormat::Rgba16Sint), + Some(spirv::ImageFormat::Rgba16f) => Ok(crate::StorageFormat::Rgba16Float), + Some(spirv::ImageFormat::Rgba32ui) => Ok(crate::StorageFormat::Rgba32Uint), + Some(spirv::ImageFormat::Rgba32i) => Ok(crate::StorageFormat::Rgba32Sint), + Some(spirv::ImageFormat::Rgba32f) => Ok(crate::StorageFormat::Rgba32Float), + _ => Err(Error::UnsupportedImageFormat(word)), + } +} + +pub(super) fn map_width(word: spirv::Word) -> Result<crate::Bytes, Error> { + (word >> 3) // bits to bytes + .try_into() + .map_err(|_| Error::InvalidTypeWidth(word)) +} + +pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::BuiltIn, Error> { + use spirv::BuiltIn as Bi; + Ok(match spirv::BuiltIn::from_u32(word) { + Some(Bi::Position | Bi::FragCoord) => crate::BuiltIn::Position { invariant }, + Some(Bi::ViewIndex) => crate::BuiltIn::ViewIndex, + // vertex + Some(Bi::BaseInstance) => crate::BuiltIn::BaseInstance, + Some(Bi::BaseVertex) => crate::BuiltIn::BaseVertex, + Some(Bi::ClipDistance) => crate::BuiltIn::ClipDistance, + Some(Bi::CullDistance) => crate::BuiltIn::CullDistance, + Some(Bi::InstanceIndex) => crate::BuiltIn::InstanceIndex, + Some(Bi::PointSize) => crate::BuiltIn::PointSize, + Some(Bi::VertexIndex) => crate::BuiltIn::VertexIndex, + // fragment + Some(Bi::FragDepth) => crate::BuiltIn::FragDepth, + Some(Bi::PointCoord) => crate::BuiltIn::PointCoord, + Some(Bi::FrontFacing) => crate::BuiltIn::FrontFacing, + Some(Bi::PrimitiveId) => crate::BuiltIn::PrimitiveIndex, + Some(Bi::SampleId) => crate::BuiltIn::SampleIndex, + Some(Bi::SampleMask) => crate::BuiltIn::SampleMask, + // compute + Some(Bi::GlobalInvocationId) => crate::BuiltIn::GlobalInvocationId, + Some(Bi::LocalInvocationId) => crate::BuiltIn::LocalInvocationId, + Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex, + Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId, + Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize, + Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups, + _ => return Err(Error::UnsupportedBuiltIn(word)), + }) +} + +pub(super) fn map_storage_class(word: spirv::Word) -> Result<super::ExtendedClass, Error> { + use super::ExtendedClass as Ec; + use spirv::StorageClass as Sc; + Ok(match Sc::from_u32(word) { + Some(Sc::Function) => Ec::Global(crate::AddressSpace::Function), + Some(Sc::Input) => Ec::Input, + Some(Sc::Output) => Ec::Output, + Some(Sc::Private) => Ec::Global(crate::AddressSpace::Private), + Some(Sc::UniformConstant) => Ec::Global(crate::AddressSpace::Handle), + Some(Sc::StorageBuffer) => Ec::Global(crate::AddressSpace::Storage { + //Note: this is restricted by decorations later + access: crate::StorageAccess::all(), + }), + // we expect the `Storage` case to be filtered out before calling this function. + Some(Sc::Uniform) => Ec::Global(crate::AddressSpace::Uniform), + Some(Sc::Workgroup) => Ec::Global(crate::AddressSpace::WorkGroup), + Some(Sc::PushConstant) => Ec::Global(crate::AddressSpace::PushConstant), + _ => return Err(Error::UnsupportedStorageClass(word)), + }) +} diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs new file mode 100644 index 0000000000..af025636c0 --- /dev/null +++ b/third_party/rust/naga/src/front/spv/error.rs @@ -0,0 +1,154 @@ +use super::ModuleState; +use crate::arena::Handle; +use codespan_reporting::diagnostic::Diagnostic; +use codespan_reporting::files::SimpleFile; +use codespan_reporting::term; +use termcolor::{NoColor, WriteColor}; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("invalid header")] + InvalidHeader, + #[error("invalid word count")] + InvalidWordCount, + #[error("unknown instruction {0}")] + UnknownInstruction(u16), + #[error("unknown capability %{0}")] + UnknownCapability(spirv::Word), + #[error("unsupported instruction {1:?} at {0:?}")] + UnsupportedInstruction(ModuleState, spirv::Op), + #[error("unsupported capability {0:?}")] + UnsupportedCapability(spirv::Capability), + #[error("unsupported extension {0}")] + UnsupportedExtension(String), + #[error("unsupported extension set {0}")] + UnsupportedExtSet(String), + #[error("unsupported extension instantiation set %{0}")] + UnsupportedExtInstSet(spirv::Word), + #[error("unsupported extension instantiation %{0}")] + UnsupportedExtInst(spirv::Word), + #[error("unsupported type {0:?}")] + UnsupportedType(Handle<crate::Type>), + #[error("unsupported execution model %{0}")] + UnsupportedExecutionModel(spirv::Word), + #[error("unsupported execution mode %{0}")] + UnsupportedExecutionMode(spirv::Word), + #[error("unsupported storage class %{0}")] + UnsupportedStorageClass(spirv::Word), + #[error("unsupported image dimension %{0}")] + UnsupportedImageDim(spirv::Word), + #[error("unsupported image format %{0}")] + UnsupportedImageFormat(spirv::Word), + #[error("unsupported builtin %{0}")] + UnsupportedBuiltIn(spirv::Word), + #[error("unsupported control flow %{0}")] + UnsupportedControlFlow(spirv::Word), + #[error("unsupported binary operator %{0}")] + UnsupportedBinaryOperator(spirv::Word), + #[error("Naga supports OpTypeRuntimeArray in the StorageBuffer storage class only")] + UnsupportedRuntimeArrayStorageClass, + #[error("unsupported matrix stride {stride} for a {columns}x{rows} matrix with scalar width={width}")] + UnsupportedMatrixStride { + stride: u32, + columns: u8, + rows: u8, + width: u8, + }, + #[error("unknown binary operator {0:?}")] + UnknownBinaryOperator(spirv::Op), + #[error("unknown relational function {0:?}")] + UnknownRelationalFunction(spirv::Op), + #[error("invalid parameter {0:?}")] + InvalidParameter(spirv::Op), + #[error("invalid operand count {1} for {0:?}")] + InvalidOperandCount(spirv::Op, u16), + #[error("invalid operand")] + InvalidOperand, + #[error("invalid id %{0}")] + InvalidId(spirv::Word), + #[error("invalid decoration %{0}")] + InvalidDecoration(spirv::Word), + #[error("invalid type width %{0}")] + InvalidTypeWidth(spirv::Word), + #[error("invalid sign %{0}")] + InvalidSign(spirv::Word), + #[error("invalid inner type %{0}")] + InvalidInnerType(spirv::Word), + #[error("invalid vector size %{0}")] + InvalidVectorSize(spirv::Word), + #[error("invalid access type %{0}")] + InvalidAccessType(spirv::Word), + #[error("invalid access {0:?}")] + InvalidAccess(crate::Expression), + #[error("invalid access index %{0}")] + InvalidAccessIndex(spirv::Word), + #[error("invalid index type %{0}")] + InvalidIndexType(spirv::Word), + #[error("invalid binding %{0}")] + InvalidBinding(spirv::Word), + #[error("invalid global var {0:?}")] + InvalidGlobalVar(crate::Expression), + #[error("invalid image/sampler expression {0:?}")] + InvalidImageExpression(crate::Expression), + #[error("invalid image base type {0:?}")] + InvalidImageBaseType(Handle<crate::Type>), + #[error("invalid image {0:?}")] + InvalidImage(Handle<crate::Type>), + #[error("invalid as type {0:?}")] + InvalidAsType(Handle<crate::Type>), + #[error("invalid vector type {0:?}")] + InvalidVectorType(Handle<crate::Type>), + #[error("inconsistent comparison sampling {0:?}")] + InconsistentComparisonSampling(Handle<crate::GlobalVariable>), + #[error("wrong function result type %{0}")] + WrongFunctionResultType(spirv::Word), + #[error("wrong function argument type %{0}")] + WrongFunctionArgumentType(spirv::Word), + #[error("missing decoration {0:?}")] + MissingDecoration(spirv::Decoration), + #[error("bad string")] + BadString, + #[error("incomplete data")] + IncompleteData, + #[error("invalid terminator")] + InvalidTerminator, + #[error("invalid edge classification")] + InvalidEdgeClassification, + #[error("cycle detected in the CFG during traversal at {0}")] + ControlFlowGraphCycle(crate::front::spv::BlockId), + #[error("recursive function call %{0}")] + FunctionCallCycle(spirv::Word), + #[error("invalid array size {0:?}")] + InvalidArraySize(Handle<crate::Constant>), + #[error("invalid barrier scope %{0}")] + InvalidBarrierScope(spirv::Word), + #[error("invalid barrier memory semantics %{0}")] + InvalidBarrierMemorySemantics(spirv::Word), + #[error( + "arrays of images / samplers are supported only through bindings for \ + now (i.e. you can't create an array of images or samplers that doesn't \ + come from a binding)" + )] + NonBindingArrayOfImageOrSamplers, +} + +impl Error { + pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) { + self.emit_to_writer_with_path(writer, source, "glsl"); + } + + pub fn emit_to_writer_with_path(&self, writer: &mut impl WriteColor, source: &str, path: &str) { + let path = path.to_string(); + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + let diagnostic = Diagnostic::error().with_message(format!("{self:?}")); + + term::emit(writer, &config, &files, &diagnostic).expect("cannot write error"); + } + + pub fn emit_to_string(&self, source: &str) -> String { + let mut writer = NoColor::new(Vec::new()); + self.emit_to_writer(&mut writer, source); + String::from_utf8(writer.into_inner()).unwrap() + } +} diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs new file mode 100644 index 0000000000..198d9c52dd --- /dev/null +++ b/third_party/rust/naga/src/front/spv/function.rs @@ -0,0 +1,674 @@ +use crate::{ + arena::{Arena, Handle}, + front::spv::{BlockContext, BodyIndex}, +}; + +use super::{Error, Instruction, LookupExpression, LookupHelper as _}; +use crate::proc::Emitter; + +pub type BlockId = u32; + +#[derive(Copy, Clone, Debug)] +pub struct MergeInstruction { + pub merge_block_id: BlockId, + pub continue_block_id: Option<BlockId>, +} + +impl<I: Iterator<Item = u32>> super::Frontend<I> { + // Registers a function call. It will generate a dummy handle to call, which + // gets resolved after all the functions are processed. + pub(super) fn add_call( + &mut self, + from: spirv::Word, + to: spirv::Word, + ) -> Handle<crate::Function> { + let dummy_handle = self + .dummy_functions + .append(crate::Function::default(), Default::default()); + self.deferred_function_calls.push(to); + self.function_call_graph.add_edge(from, to, ()); + dummy_handle + } + + pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> { + let start = self.data_offset; + self.lookup_expression.clear(); + self.lookup_load_override.clear(); + self.lookup_sampled_image.clear(); + + let result_type_id = self.next()?; + let fun_id = self.next()?; + let _fun_control = self.next()?; + let fun_type_id = self.next()?; + + let mut fun = { + let ft = self.lookup_function_type.lookup(fun_type_id)?; + if ft.return_type_id != result_type_id { + return Err(Error::WrongFunctionResultType(result_type_id)); + } + crate::Function { + name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name), + arguments: Vec::with_capacity(ft.parameter_type_ids.len()), + result: if self.lookup_void_type == Some(result_type_id) { + None + } else { + let lookup_result_ty = self.lookup_type.lookup(result_type_id)?; + Some(crate::FunctionResult { + ty: lookup_result_ty.handle, + binding: None, + }) + }, + local_variables: Arena::new(), + expressions: self + .make_expression_storage(&module.global_variables, &module.constants), + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::new(), + } + }; + + // read parameters + for i in 0..fun.arguments.capacity() { + let start = self.data_offset; + match self.next_inst()? { + Instruction { + op: spirv::Op::FunctionParameter, + wc: 3, + } => { + let type_id = self.next()?; + let id = self.next()?; + let handle = fun.expressions.append( + crate::Expression::FunctionArgument(i as u32), + self.span_from(start), + ); + self.lookup_expression.insert( + id, + LookupExpression { + handle, + type_id, + // Setting this to an invalid id will cause get_expr_handle + // to default to the main body making sure no load/stores + // are added. + block_id: 0, + }, + ); + //Note: we redo the lookup in order to work around `self` borrowing + + if type_id + != self + .lookup_function_type + .lookup(fun_type_id)? + .parameter_type_ids[i] + { + return Err(Error::WrongFunctionArgumentType(type_id)); + } + let ty = self.lookup_type.lookup(type_id)?.handle; + let decor = self.future_decor.remove(&id).unwrap_or_default(); + fun.arguments.push(crate::FunctionArgument { + name: decor.name, + ty, + binding: None, + }); + } + Instruction { op, .. } => return Err(Error::InvalidParameter(op)), + } + } + + // Read body + self.function_call_graph.add_node(fun_id); + let mut parameters_sampling = + vec![super::image::SamplingFlags::empty(); fun.arguments.len()]; + + let mut block_ctx = BlockContext { + phis: Default::default(), + blocks: Default::default(), + body_for_label: Default::default(), + mergers: Default::default(), + bodies: Default::default(), + function_id: fun_id, + expressions: &mut fun.expressions, + local_arena: &mut fun.local_variables, + const_arena: &mut module.constants, + const_expressions: &mut module.const_expressions, + type_arena: &module.types, + global_arena: &module.global_variables, + arguments: &fun.arguments, + parameter_sampling: &mut parameters_sampling, + }; + // Insert the main body whose parent is also himself + block_ctx.bodies.push(super::Body::with_parent(0)); + + // Scan the blocks and add them as nodes + loop { + let fun_inst = self.next_inst()?; + log::debug!("{:?}", fun_inst.op); + match fun_inst.op { + spirv::Op::Line => { + fun_inst.expect(4)?; + let _file_id = self.next()?; + let _row_id = self.next()?; + let _col_id = self.next()?; + } + spirv::Op::Label => { + // Read the label ID + fun_inst.expect(2)?; + let block_id = self.next()?; + + self.next_block(block_id, &mut block_ctx)?; + } + spirv::Op::FunctionEnd => { + fun_inst.expect(1)?; + break; + } + _ => { + return Err(Error::UnsupportedInstruction(self.state, fun_inst.op)); + } + } + } + + if let Some(ref prefix) = self.options.block_ctx_dump_prefix { + let dump_suffix = match self.lookup_entry_point.get(&fun_id) { + Some(ep) => format!("block_ctx.{:?}-{}.txt", ep.stage, ep.name), + None => format!("block_ctx.Fun-{}.txt", module.functions.len()), + }; + let dest = prefix.join(dump_suffix); + let dump = format!("{block_ctx:#?}"); + if let Err(e) = std::fs::write(&dest, dump) { + log::error!("Unable to dump the block context into {:?}: {}", dest, e); + } + } + + // Emit `Store` statements to properly initialize all the local variables we + // created for `phi` expressions. + // + // Note that get_expr_handle also contributes slightly odd entries to this table, + // to get the spill. + for phi in block_ctx.phis.iter() { + // Get a pointer to the local variable for the phi's value. + let phi_pointer = block_ctx.expressions.append( + crate::Expression::LocalVariable(phi.local), + crate::Span::default(), + ); + + // At the end of each of `phi`'s predecessor blocks, store the corresponding + // source value in the phi's local variable. + for &(source, predecessor) in phi.expressions.iter() { + let source_lexp = &self.lookup_expression[&source]; + let predecessor_body_idx = block_ctx.body_for_label[&predecessor]; + // If the expression is a global/argument it will have a 0 block + // id so we must use a default value instead of panicking + let source_body_idx = block_ctx + .body_for_label + .get(&source_lexp.block_id) + .copied() + .unwrap_or(0); + + // If the Naga `Expression` generated for `source` is in scope, then we + // can simply store that in the phi's local variable. + // + // Otherwise, spill the source value to a local variable in the block that + // defines it. (We know this store dominates the predecessor; otherwise, + // the phi wouldn't have been able to refer to that source expression in + // the first place.) Then, the predecessor block can count on finding the + // source's value in that local variable. + let value = if super::is_parent(predecessor_body_idx, source_body_idx, &block_ctx) { + source_lexp.handle + } else { + // The source SPIR-V expression is not defined in the phi's + // predecessor block, nor is it a globally available expression. So it + // must be defined off in some other block that merely dominates the + // predecessor. This means that the corresponding Naga `Expression` + // may not be in scope in the predecessor block. + // + // In the block that defines `source`, spill it to a fresh local + // variable, to ensure we can still use it at the end of the + // predecessor. + let ty = self.lookup_type[&source_lexp.type_id].handle; + let local = block_ctx.local_arena.append( + crate::LocalVariable { + name: None, + ty, + init: None, + }, + crate::Span::default(), + ); + + let pointer = block_ctx.expressions.append( + crate::Expression::LocalVariable(local), + crate::Span::default(), + ); + + // Get the spilled value of the source expression. + let start = block_ctx.expressions.len(); + let expr = block_ctx + .expressions + .append(crate::Expression::Load { pointer }, crate::Span::default()); + let range = block_ctx.expressions.range_from(start); + + block_ctx + .blocks + .get_mut(&predecessor) + .unwrap() + .push(crate::Statement::Emit(range), crate::Span::default()); + + // At the end of the block that defines it, spill the source + // expression's value. + block_ctx + .blocks + .get_mut(&source_lexp.block_id) + .unwrap() + .push( + crate::Statement::Store { + pointer, + value: source_lexp.handle, + }, + crate::Span::default(), + ); + + expr + }; + + // At the end of the phi predecessor block, store the source + // value in the phi's value. + block_ctx.blocks.get_mut(&predecessor).unwrap().push( + crate::Statement::Store { + pointer: phi_pointer, + value, + }, + crate::Span::default(), + ) + } + } + + fun.body = block_ctx.lower(); + + // done + let fun_handle = module.functions.append(fun, self.span_from_with_op(start)); + self.lookup_function.insert( + fun_id, + super::LookupFunction { + handle: fun_handle, + parameters_sampling, + }, + ); + + if let Some(ep) = self.lookup_entry_point.remove(&fun_id) { + // create a wrapping function + let mut function = crate::Function { + name: Some(format!("{}_wrap", ep.name)), + arguments: Vec::new(), + result: None, + local_variables: Arena::new(), + expressions: Arena::new(), + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::new(), + }; + + // 1. copy the inputs from arguments to privates + for &v_id in ep.variable_ids.iter() { + let lvar = self.lookup_variable.lookup(v_id)?; + if let super::Variable::Input(ref arg) = lvar.inner { + let span = module.global_variables.get_span(lvar.handle); + let arg_expr = function.expressions.append( + crate::Expression::FunctionArgument(function.arguments.len() as u32), + span, + ); + let load_expr = if arg.ty == module.global_variables[lvar.handle].ty { + arg_expr + } else { + // The only case where the type is different is if we need to treat + // unsigned integer as signed. + let mut emitter = Emitter::default(); + emitter.start(&function.expressions); + let handle = function.expressions.append( + crate::Expression::As { + expr: arg_expr, + kind: crate::ScalarKind::Sint, + convert: Some(4), + }, + span, + ); + function.body.extend(emitter.finish(&function.expressions)); + handle + }; + function.body.push( + crate::Statement::Store { + pointer: function + .expressions + .append(crate::Expression::GlobalVariable(lvar.handle), span), + value: load_expr, + }, + span, + ); + + let mut arg = arg.clone(); + if ep.stage == crate::ShaderStage::Fragment { + if let Some(ref mut binding) = arg.binding { + binding.apply_default_interpolation(&module.types[arg.ty].inner); + } + } + function.arguments.push(arg); + } + } + // 2. call the wrapped function + let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision + let dummy_handle = self.add_call(fake_id, fun_id); + function.body.push( + crate::Statement::Call { + function: dummy_handle, + arguments: Vec::new(), + result: None, + }, + crate::Span::default(), + ); + + // 3. copy the outputs from privates to the result + let mut members = Vec::new(); + let mut components = Vec::new(); + for &v_id in ep.variable_ids.iter() { + let lvar = self.lookup_variable.lookup(v_id)?; + if let super::Variable::Output(ref result) = lvar.inner { + let span = module.global_variables.get_span(lvar.handle); + let expr_handle = function + .expressions + .append(crate::Expression::GlobalVariable(lvar.handle), span); + + // Cull problematic builtins of gl_PerVertex. + // See the docs for `Frontend::gl_per_vertex_builtin_access`. + { + let ty = &module.types[result.ty]; + match ty.inner { + crate::TypeInner::Struct { + members: ref original_members, + span, + } if ty.name.as_deref() == Some("gl_PerVertex") => { + let mut new_members = original_members.clone(); + for member in &mut new_members { + if let Some(crate::Binding::BuiltIn(built_in)) = member.binding + { + if !self.gl_per_vertex_builtin_access.contains(&built_in) { + member.binding = None + } + } + } + if &new_members != original_members { + module.types.replace( + result.ty, + crate::Type { + name: ty.name.clone(), + inner: crate::TypeInner::Struct { + members: new_members, + span, + }, + }, + ); + } + } + _ => {} + } + } + + match module.types[result.ty].inner { + crate::TypeInner::Struct { + members: ref sub_members, + .. + } => { + for (index, sm) in sub_members.iter().enumerate() { + if sm.binding.is_none() { + continue; + } + let mut sm = sm.clone(); + + if let Some(ref mut binding) = sm.binding { + if ep.stage == crate::ShaderStage::Vertex { + binding.apply_default_interpolation( + &module.types[sm.ty].inner, + ); + } + } + + members.push(sm); + + components.push(function.expressions.append( + crate::Expression::AccessIndex { + base: expr_handle, + index: index as u32, + }, + span, + )); + } + } + ref inner => { + let mut binding = result.binding.clone(); + if let Some(ref mut binding) = binding { + if ep.stage == crate::ShaderStage::Vertex { + binding.apply_default_interpolation(inner); + } + } + + members.push(crate::StructMember { + name: None, + ty: result.ty, + binding, + offset: 0, + }); + // populate just the globals first, then do `Load` in a + // separate step, so that we can get a range. + components.push(expr_handle); + } + } + } + } + + for (member_index, member) in members.iter().enumerate() { + match member.binding { + Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. })) + if self.options.adjust_coordinate_space => + { + let mut emitter = Emitter::default(); + emitter.start(&function.expressions); + let global_expr = components[member_index]; + let span = function.expressions.get_span(global_expr); + let access_expr = function.expressions.append( + crate::Expression::AccessIndex { + base: global_expr, + index: 1, + }, + span, + ); + let load_expr = function.expressions.append( + crate::Expression::Load { + pointer: access_expr, + }, + span, + ); + let neg_expr = function.expressions.append( + crate::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr: load_expr, + }, + span, + ); + function.body.extend(emitter.finish(&function.expressions)); + function.body.push( + crate::Statement::Store { + pointer: access_expr, + value: neg_expr, + }, + span, + ); + } + _ => {} + } + } + + let mut emitter = Emitter::default(); + emitter.start(&function.expressions); + for component in components.iter_mut() { + let load_expr = crate::Expression::Load { + pointer: *component, + }; + let span = function.expressions.get_span(*component); + *component = function.expressions.append(load_expr, span); + } + + match members[..] { + [] => {} + [ref member] => { + function.body.extend(emitter.finish(&function.expressions)); + let span = function.expressions.get_span(components[0]); + function.body.push( + crate::Statement::Return { + value: components.first().cloned(), + }, + span, + ); + function.result = Some(crate::FunctionResult { + ty: member.ty, + binding: member.binding.clone(), + }); + } + _ => { + let span = crate::Span::total_span( + components.iter().map(|h| function.expressions.get_span(*h)), + ); + let ty = module.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Struct { + members, + span: 0xFFFF, // shouldn't matter + }, + }, + span, + ); + let result_expr = function + .expressions + .append(crate::Expression::Compose { ty, components }, span); + function.body.extend(emitter.finish(&function.expressions)); + function.body.push( + crate::Statement::Return { + value: Some(result_expr), + }, + span, + ); + function.result = Some(crate::FunctionResult { ty, binding: None }); + } + } + + module.entry_points.push(crate::EntryPoint { + name: ep.name, + stage: ep.stage, + early_depth_test: ep.early_depth_test, + workgroup_size: ep.workgroup_size, + function, + }); + } + + Ok(()) + } +} + +impl<'function> BlockContext<'function> { + pub(super) fn gctx(&self) -> crate::proc::GlobalCtx { + crate::proc::GlobalCtx { + types: self.type_arena, + constants: self.const_arena, + const_expressions: self.const_expressions, + } + } + + /// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block) + fn lower(mut self) -> crate::Block { + fn lower_impl( + blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>, + bodies: &[super::Body], + body_idx: BodyIndex, + ) -> crate::Block { + let mut block = crate::Block::new(); + + for item in bodies[body_idx].data.iter() { + match *item { + super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()), + super::BodyFragment::If { + condition, + accept, + reject, + } => { + let accept = lower_impl(blocks, bodies, accept); + let reject = lower_impl(blocks, bodies, reject); + + block.push( + crate::Statement::If { + condition, + accept, + reject, + }, + crate::Span::default(), + ) + } + super::BodyFragment::Loop { + body, + continuing, + break_if, + } => { + let body = lower_impl(blocks, bodies, body); + let continuing = lower_impl(blocks, bodies, continuing); + + block.push( + crate::Statement::Loop { + body, + continuing, + break_if, + }, + crate::Span::default(), + ) + } + super::BodyFragment::Switch { + selector, + ref cases, + default, + } => { + let mut ir_cases: Vec<_> = cases + .iter() + .map(|&(value, body_idx)| { + let body = lower_impl(blocks, bodies, body_idx); + + // Handle simple cases that would make a fallthrough statement unreachable code + let fall_through = body.last().map_or(true, |s| !s.is_terminator()); + + crate::SwitchCase { + value: crate::SwitchValue::I32(value), + body, + fall_through, + } + }) + .collect(); + ir_cases.push(crate::SwitchCase { + value: crate::SwitchValue::Default, + body: lower_impl(blocks, bodies, default), + fall_through: false, + }); + + block.push( + crate::Statement::Switch { + selector, + cases: ir_cases, + }, + crate::Span::default(), + ) + } + super::BodyFragment::Break => { + block.push(crate::Statement::Break, crate::Span::default()) + } + super::BodyFragment::Continue => { + block.push(crate::Statement::Continue, crate::Span::default()) + } + } + } + + block + } + + lower_impl(&mut self.blocks, &self.bodies, 0) + } +} diff --git a/third_party/rust/naga/src/front/spv/image.rs b/third_party/rust/naga/src/front/spv/image.rs new file mode 100644 index 0000000000..0f25dd626b --- /dev/null +++ b/third_party/rust/naga/src/front/spv/image.rs @@ -0,0 +1,767 @@ +use crate::{ + arena::{Handle, UniqueArena}, + Scalar, +}; + +use super::{Error, LookupExpression, LookupHelper as _}; + +#[derive(Clone, Debug)] +pub(super) struct LookupSampledImage { + image: Handle<crate::Expression>, + sampler: Handle<crate::Expression>, +} + +bitflags::bitflags! { + /// Flags describing sampling method. + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct SamplingFlags: u32 { + /// Regular sampling. + const REGULAR = 0x1; + /// Comparison sampling. + const COMPARISON = 0x2; + } +} + +impl<'function> super::BlockContext<'function> { + fn get_image_expr_ty( + &self, + handle: Handle<crate::Expression>, + ) -> Result<Handle<crate::Type>, Error> { + match self.expressions[handle] { + crate::Expression::GlobalVariable(handle) => Ok(self.global_arena[handle].ty), + crate::Expression::FunctionArgument(i) => Ok(self.arguments[i as usize].ty), + ref other => Err(Error::InvalidImageExpression(other.clone())), + } + } +} + +/// Options of a sampling operation. +#[derive(Debug)] +pub struct SamplingOptions { + /// Projection sampling: the division by W is expected to happen + /// in the texture unit. + pub project: bool, + /// Depth comparison sampling with a reference value. + pub compare: bool, +} + +enum ExtraCoordinate { + ArrayLayer, + Projection, + Garbage, +} + +/// Return the texture coordinates separated from the array layer, +/// and/or divided by the projection term. +/// +/// The Proj sampling ops expect an extra coordinate for the W. +/// The arrayed (can't be Proj!) images expect an extra coordinate for the layer. +fn extract_image_coordinates( + image_dim: crate::ImageDimension, + extra_coordinate: ExtraCoordinate, + base: Handle<crate::Expression>, + coordinate_ty: Handle<crate::Type>, + ctx: &mut super::BlockContext, +) -> (Handle<crate::Expression>, Option<Handle<crate::Expression>>) { + let (given_size, kind) = match ctx.type_arena[coordinate_ty].inner { + crate::TypeInner::Scalar(Scalar { kind, .. }) => (None, kind), + crate::TypeInner::Vector { + size, + scalar: Scalar { kind, .. }, + } => (Some(size), kind), + ref other => unreachable!("Unexpected texture coordinate {:?}", other), + }; + + let required_size = image_dim.required_coordinate_size(); + let required_ty = required_size.map(|size| { + ctx.type_arena + .get(&crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size, + scalar: Scalar { kind, width: 4 }, + }, + }) + .expect("Required coordinate type should have been set up by `parse_type_image`!") + }); + let extra_expr = crate::Expression::AccessIndex { + base, + index: required_size.map_or(1, |size| size as u32), + }; + + let base_span = ctx.expressions.get_span(base); + + match extra_coordinate { + ExtraCoordinate::ArrayLayer => { + let extracted = match required_size { + None => ctx + .expressions + .append(crate::Expression::AccessIndex { base, index: 0 }, base_span), + Some(size) => { + let mut components = Vec::with_capacity(size as usize); + for index in 0..size as u32 { + let comp = ctx + .expressions + .append(crate::Expression::AccessIndex { base, index }, base_span); + components.push(comp); + } + ctx.expressions.append( + crate::Expression::Compose { + ty: required_ty.unwrap(), + components, + }, + base_span, + ) + } + }; + let array_index_f32 = ctx.expressions.append(extra_expr, base_span); + let array_index = ctx.expressions.append( + crate::Expression::As { + kind: crate::ScalarKind::Sint, + expr: array_index_f32, + convert: Some(4), + }, + base_span, + ); + (extracted, Some(array_index)) + } + ExtraCoordinate::Projection => { + let projection = ctx.expressions.append(extra_expr, base_span); + let divided = match required_size { + None => { + let temp = ctx + .expressions + .append(crate::Expression::AccessIndex { base, index: 0 }, base_span); + ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left: temp, + right: projection, + }, + base_span, + ) + } + Some(size) => { + let mut components = Vec::with_capacity(size as usize); + for index in 0..size as u32 { + let temp = ctx + .expressions + .append(crate::Expression::AccessIndex { base, index }, base_span); + let comp = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left: temp, + right: projection, + }, + base_span, + ); + components.push(comp); + } + ctx.expressions.append( + crate::Expression::Compose { + ty: required_ty.unwrap(), + components, + }, + base_span, + ) + } + }; + (divided, None) + } + ExtraCoordinate::Garbage if given_size == required_size => (base, None), + ExtraCoordinate::Garbage => { + use crate::SwizzleComponent as Sc; + let cut_expr = match required_size { + None => crate::Expression::AccessIndex { base, index: 0 }, + Some(size) => crate::Expression::Swizzle { + size, + vector: base, + pattern: [Sc::X, Sc::Y, Sc::Z, Sc::W], + }, + }; + (ctx.expressions.append(cut_expr, base_span), None) + } + } +} + +pub(super) fn patch_comparison_type( + flags: SamplingFlags, + var: &mut crate::GlobalVariable, + arena: &mut UniqueArena<crate::Type>, +) -> bool { + if !flags.contains(SamplingFlags::COMPARISON) { + return true; + } + if flags == SamplingFlags::all() { + return false; + } + + log::debug!("Flipping comparison for {:?}", var); + let original_ty = &arena[var.ty]; + let original_ty_span = arena.get_span(var.ty); + let ty_inner = match original_ty.inner { + crate::TypeInner::Image { + class: crate::ImageClass::Sampled { multi, .. }, + dim, + arrayed, + } => crate::TypeInner::Image { + class: crate::ImageClass::Depth { multi }, + dim, + arrayed, + }, + crate::TypeInner::Sampler { .. } => crate::TypeInner::Sampler { comparison: true }, + ref other => unreachable!("Unexpected type for comparison mutation: {:?}", other), + }; + + let name = original_ty.name.clone(); + var.ty = arena.insert( + crate::Type { + name, + inner: ty_inner, + }, + original_ty_span, + ); + true +} + +impl<I: Iterator<Item = u32>> super::Frontend<I> { + pub(super) fn parse_image_couple(&mut self) -> Result<(), Error> { + let _result_type_id = self.next()?; + let result_id = self.next()?; + let image_id = self.next()?; + let sampler_id = self.next()?; + let image_lexp = self.lookup_expression.lookup(image_id)?; + let sampler_lexp = self.lookup_expression.lookup(sampler_id)?; + self.lookup_sampled_image.insert( + result_id, + LookupSampledImage { + image: image_lexp.handle, + sampler: sampler_lexp.handle, + }, + ); + Ok(()) + } + + pub(super) fn parse_image_uncouple(&mut self, block_id: spirv::Word) -> Result<(), Error> { + let result_type_id = self.next()?; + let result_id = self.next()?; + let sampled_image_id = self.next()?; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: self.lookup_sampled_image.lookup(sampled_image_id)?.image, + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + pub(super) fn parse_image_write( + &mut self, + words_left: u16, + ctx: &mut super::BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + body_idx: usize, + ) -> Result<crate::Statement, Error> { + let image_id = self.next()?; + let coordinate_id = self.next()?; + let value_id = self.next()?; + + let image_ops = if words_left != 0 { self.next()? } else { 0 }; + + if image_ops != 0 { + let other = spirv::ImageOperands::from_bits_truncate(image_ops); + log::warn!("Unknown image write ops {:?}", other); + for _ in 1..words_left { + self.next()?; + } + } + + let image_lexp = self.lookup_expression.lookup(image_id)?; + let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?; + + let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; + let coord_handle = + self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx); + let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; + let (coordinate, array_index) = match ctx.type_arena[image_ty].inner { + crate::TypeInner::Image { + dim, + arrayed, + class: _, + } => extract_image_coordinates( + dim, + if arrayed { + ExtraCoordinate::ArrayLayer + } else { + ExtraCoordinate::Garbage + }, + coord_handle, + coord_type_handle, + ctx, + ), + _ => return Err(Error::InvalidImage(image_ty)), + }; + + let value_lexp = self.lookup_expression.lookup(value_id)?; + let value = self.get_expr_handle(value_id, value_lexp, ctx, emitter, block, body_idx); + + Ok(crate::Statement::ImageStore { + image: image_lexp.handle, + coordinate, + array_index, + value, + }) + } + + pub(super) fn parse_image_load( + &mut self, + mut words_left: u16, + ctx: &mut super::BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let image_id = self.next()?; + let coordinate_id = self.next()?; + + let mut image_ops = if words_left != 0 { + words_left -= 1; + self.next()? + } else { + 0 + }; + + let mut sample = None; + let mut level = None; + while image_ops != 0 { + let bit = 1 << image_ops.trailing_zeros(); + match spirv::ImageOperands::from_bits_truncate(bit) { + spirv::ImageOperands::LOD => { + let lod_expr = self.next()?; + let lod_lexp = self.lookup_expression.lookup(lod_expr)?; + let lod_handle = + self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx); + level = Some(lod_handle); + words_left -= 1; + } + spirv::ImageOperands::SAMPLE => { + let sample_expr = self.next()?; + let sample_handle = self.lookup_expression.lookup(sample_expr)?.handle; + sample = Some(sample_handle); + words_left -= 1; + } + other => { + log::warn!("Unknown image load op {:?}", other); + for _ in 0..words_left { + self.next()?; + } + break; + } + } + image_ops ^= bit; + } + + // No need to call get_expr_handle here since only globals/arguments are + // allowed as images and they are always in the root scope + let image_lexp = self.lookup_expression.lookup(image_id)?; + let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?; + + let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; + let coord_handle = + self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx); + let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; + let (coordinate, array_index) = match ctx.type_arena[image_ty].inner { + crate::TypeInner::Image { + dim, + arrayed, + class: _, + } => extract_image_coordinates( + dim, + if arrayed { + ExtraCoordinate::ArrayLayer + } else { + ExtraCoordinate::Garbage + }, + coord_handle, + coord_type_handle, + ctx, + ), + _ => return Err(Error::InvalidImage(image_ty)), + }; + + let expr = crate::Expression::ImageLoad { + image: image_lexp.handle, + coordinate, + array_index, + sample, + level, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + pub(super) fn parse_image_sample( + &mut self, + mut words_left: u16, + options: SamplingOptions, + ctx: &mut super::BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let sampled_image_id = self.next()?; + let coordinate_id = self.next()?; + let dref_id = if options.compare { + Some(self.next()?) + } else { + None + }; + + let mut image_ops = if words_left != 0 { + words_left -= 1; + self.next()? + } else { + 0 + }; + + let mut level = crate::SampleLevel::Auto; + let mut offset = None; + while image_ops != 0 { + let bit = 1 << image_ops.trailing_zeros(); + match spirv::ImageOperands::from_bits_truncate(bit) { + spirv::ImageOperands::BIAS => { + let bias_expr = self.next()?; + let bias_lexp = self.lookup_expression.lookup(bias_expr)?; + let bias_handle = + self.get_expr_handle(bias_expr, bias_lexp, ctx, emitter, block, body_idx); + level = crate::SampleLevel::Bias(bias_handle); + words_left -= 1; + } + spirv::ImageOperands::LOD => { + let lod_expr = self.next()?; + let lod_lexp = self.lookup_expression.lookup(lod_expr)?; + let lod_handle = + self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx); + level = if options.compare { + log::debug!("Assuming {:?} is zero", lod_handle); + crate::SampleLevel::Zero + } else { + crate::SampleLevel::Exact(lod_handle) + }; + words_left -= 1; + } + spirv::ImageOperands::GRAD => { + let grad_x_expr = self.next()?; + let grad_x_lexp = self.lookup_expression.lookup(grad_x_expr)?; + let grad_x_handle = self.get_expr_handle( + grad_x_expr, + grad_x_lexp, + ctx, + emitter, + block, + body_idx, + ); + let grad_y_expr = self.next()?; + let grad_y_lexp = self.lookup_expression.lookup(grad_y_expr)?; + let grad_y_handle = self.get_expr_handle( + grad_y_expr, + grad_y_lexp, + ctx, + emitter, + block, + body_idx, + ); + level = if options.compare { + log::debug!( + "Assuming gradients {:?} and {:?} are not greater than 1", + grad_x_handle, + grad_y_handle + ); + crate::SampleLevel::Zero + } else { + crate::SampleLevel::Gradient { + x: grad_x_handle, + y: grad_y_handle, + } + }; + words_left -= 2; + } + spirv::ImageOperands::CONST_OFFSET => { + let offset_constant = self.next()?; + let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle; + let offset_handle = ctx.const_expressions.append( + crate::Expression::Constant(offset_handle), + Default::default(), + ); + offset = Some(offset_handle); + words_left -= 1; + } + other => { + log::warn!("Unknown image sample operand {:?}", other); + for _ in 0..words_left { + self.next()?; + } + break; + } + } + image_ops ^= bit; + } + + let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?; + let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; + let coord_handle = + self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx); + let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; + + let sampling_bit = if options.compare { + SamplingFlags::COMPARISON + } else { + SamplingFlags::REGULAR + }; + + let image_ty = match ctx.expressions[si_lexp.image] { + crate::Expression::GlobalVariable(handle) => { + if let Some(flags) = self.handle_sampling.get_mut(&handle) { + *flags |= sampling_bit; + } + + ctx.global_arena[handle].ty + } + + crate::Expression::FunctionArgument(i) => { + ctx.parameter_sampling[i as usize] |= sampling_bit; + ctx.arguments[i as usize].ty + } + + crate::Expression::Access { base, .. } => match ctx.expressions[base] { + crate::Expression::GlobalVariable(handle) => { + if let Some(flags) = self.handle_sampling.get_mut(&handle) { + *flags |= sampling_bit; + } + + match ctx.type_arena[ctx.global_arena[handle].ty].inner { + crate::TypeInner::BindingArray { base, .. } => base, + _ => return Err(Error::InvalidGlobalVar(ctx.expressions[base].clone())), + } + } + + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + }, + + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + }; + + match ctx.expressions[si_lexp.sampler] { + crate::Expression::GlobalVariable(handle) => { + *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit; + } + + crate::Expression::FunctionArgument(i) => { + ctx.parameter_sampling[i as usize] |= sampling_bit; + } + + crate::Expression::Access { base, .. } => match ctx.expressions[base] { + crate::Expression::GlobalVariable(handle) => { + *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit; + } + + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + }, + + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + } + + let ((coordinate, array_index), depth_ref) = match ctx.type_arena[image_ty].inner { + crate::TypeInner::Image { + dim, + arrayed, + class: _, + } => ( + extract_image_coordinates( + dim, + if options.project { + ExtraCoordinate::Projection + } else if arrayed { + ExtraCoordinate::ArrayLayer + } else { + ExtraCoordinate::Garbage + }, + coord_handle, + coord_type_handle, + ctx, + ), + { + match dref_id { + Some(id) => { + let expr_lexp = self.lookup_expression.lookup(id)?; + let mut expr = + self.get_expr_handle(id, expr_lexp, ctx, emitter, block, body_idx); + + if options.project { + let required_size = dim.required_coordinate_size(); + let right = ctx.expressions.append( + crate::Expression::AccessIndex { + base: coord_handle, + index: required_size.map_or(1, |size| size as u32), + }, + crate::Span::default(), + ); + expr = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left: expr, + right, + }, + crate::Span::default(), + ) + }; + Some(expr) + } + None => None, + } + }, + ), + _ => return Err(Error::InvalidImage(image_ty)), + }; + + let expr = crate::Expression::ImageSample { + image: si_lexp.image, + sampler: si_lexp.sampler, + gather: None, //TODO + coordinate, + array_index, + offset, + level, + depth_ref, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + pub(super) fn parse_image_query_size( + &mut self, + at_level: bool, + ctx: &mut super::BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let image_id = self.next()?; + let level = if at_level { + let level_id = self.next()?; + let level_lexp = self.lookup_expression.lookup(level_id)?; + Some(self.get_expr_handle(level_id, level_lexp, ctx, emitter, block, body_idx)) + } else { + None + }; + + // No need to call get_expr_handle here since only globals/arguments are + // allowed as images and they are always in the root scope + //TODO: handle arrays and cubes + let image_lexp = self.lookup_expression.lookup(image_id)?; + + let expr = crate::Expression::ImageQuery { + image: image_lexp.handle, + query: crate::ImageQuery::Size { level }, + }; + + let result_type_handle = self.lookup_type.lookup(result_type_id)?.handle; + let maybe_scalar_kind = ctx.type_arena[result_type_handle].inner.scalar_kind(); + + let expr = if maybe_scalar_kind == Some(crate::ScalarKind::Sint) { + crate::Expression::As { + expr: ctx.expressions.append(expr, self.span_from_with_op(start)), + kind: crate::ScalarKind::Sint, + convert: Some(4), + } + } else { + expr + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + + Ok(()) + } + + pub(super) fn parse_image_query_other( + &mut self, + query: crate::ImageQuery, + ctx: &mut super::BlockContext, + block_id: spirv::Word, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let image_id = self.next()?; + + // No need to call get_expr_handle here since only globals/arguments are + // allowed as images and they are always in the root scope + let image_lexp = self.lookup_expression.lookup(image_id)?.clone(); + + let expr = crate::Expression::ImageQuery { + image: image_lexp.handle, + query, + }; + + let result_type_handle = self.lookup_type.lookup(result_type_id)?.handle; + let maybe_scalar_kind = ctx.type_arena[result_type_handle].inner.scalar_kind(); + + let expr = if maybe_scalar_kind == Some(crate::ScalarKind::Sint) { + crate::Expression::As { + expr: ctx.expressions.append(expr, self.span_from_with_op(start)), + kind: crate::ScalarKind::Sint, + convert: Some(4), + } + } else { + expr + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs new file mode 100644 index 0000000000..8b1c854358 --- /dev/null +++ b/third_party/rust/naga/src/front/spv/mod.rs @@ -0,0 +1,5356 @@ +/*! +Frontend for [SPIR-V][spv] (Standard Portable Intermediate Representation). + +## ID lookups + +Our IR links to everything with `Handle`, while SPIR-V uses IDs. +In order to keep track of the associations, the parser has many lookup tables. +There map `spv::Word` into a specific IR handle, plus potentially a bit of +extra info, such as the related SPIR-V type ID. +TODO: would be nice to find ways that avoid looking up as much + +## Inputs/Outputs + +We create a private variable for each input/output. The relevant inputs are +populated at the start of an entry point. The outputs are saved at the end. + +The function associated with an entry point is wrapped in another function, +such that we can handle any `Return` statements without problems. + +## Row-major matrices + +We don't handle them natively, since the IR only expects column majority. +Instead, we detect when such matrix is accessed in the `OpAccessChain`, +and we generate a parallel expression that loads the value, but transposed. +This value then gets used instead of `OpLoad` result later on. + +[spv]: https://www.khronos.org/registry/SPIR-V/ +*/ + +mod convert; +mod error; +mod function; +mod image; +mod null; + +use convert::*; +pub use error::Error; +use function::*; + +use crate::{ + arena::{Arena, Handle, UniqueArena}, + proc::{Alignment, Layouter}, + FastHashMap, FastHashSet, FastIndexMap, +}; + +use petgraph::graphmap::GraphMap; +use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf}; + +pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[ + spirv::Capability::Shader, + spirv::Capability::VulkanMemoryModel, + spirv::Capability::ClipDistance, + spirv::Capability::CullDistance, + spirv::Capability::SampleRateShading, + spirv::Capability::DerivativeControl, + spirv::Capability::Matrix, + spirv::Capability::ImageQuery, + spirv::Capability::Sampled1D, + spirv::Capability::Image1D, + spirv::Capability::SampledCubeArray, + spirv::Capability::ImageCubeArray, + spirv::Capability::StorageImageExtendedFormats, + spirv::Capability::Int8, + spirv::Capability::Int16, + spirv::Capability::Int64, + spirv::Capability::Float16, + spirv::Capability::Float64, + spirv::Capability::Geometry, + spirv::Capability::MultiView, + // tricky ones + spirv::Capability::UniformBufferArrayDynamicIndexing, + spirv::Capability::StorageBufferArrayDynamicIndexing, +]; +pub const SUPPORTED_EXTENSIONS: &[&str] = &[ + "SPV_KHR_storage_buffer_storage_class", + "SPV_KHR_vulkan_memory_model", + "SPV_KHR_multiview", +]; +pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"]; + +#[derive(Copy, Clone)] +pub struct Instruction { + op: spirv::Op, + wc: u16, +} + +impl Instruction { + const fn expect(self, count: u16) -> Result<(), Error> { + if self.wc == count { + Ok(()) + } else { + Err(Error::InvalidOperandCount(self.op, self.wc)) + } + } + + fn expect_at_least(self, count: u16) -> Result<u16, Error> { + self.wc + .checked_sub(count) + .ok_or(Error::InvalidOperandCount(self.op, self.wc)) + } +} + +impl crate::TypeInner { + fn can_comparison_sample(&self, module: &crate::Module) -> bool { + match *self { + crate::TypeInner::Image { + class: + crate::ImageClass::Sampled { + kind: crate::ScalarKind::Float, + multi: false, + }, + .. + } => true, + crate::TypeInner::Sampler { .. } => true, + crate::TypeInner::BindingArray { base, .. } => { + module.types[base].inner.can_comparison_sample(module) + } + _ => false, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] +pub enum ModuleState { + Empty, + Capability, + Extension, + ExtInstImport, + MemoryModel, + EntryPoint, + ExecutionMode, + Source, + Name, + ModuleProcessed, + Annotation, + Type, + Function, +} + +trait LookupHelper { + type Target; + fn lookup(&self, key: spirv::Word) -> Result<&Self::Target, Error>; +} + +impl<T> LookupHelper for FastHashMap<spirv::Word, T> { + type Target = T; + fn lookup(&self, key: spirv::Word) -> Result<&T, Error> { + self.get(&key).ok_or(Error::InvalidId(key)) + } +} + +impl crate::ImageDimension { + const fn required_coordinate_size(&self) -> Option<crate::VectorSize> { + match *self { + crate::ImageDimension::D1 => None, + crate::ImageDimension::D2 => Some(crate::VectorSize::Bi), + crate::ImageDimension::D3 => Some(crate::VectorSize::Tri), + crate::ImageDimension::Cube => Some(crate::VectorSize::Tri), + } + } +} + +type MemberIndex = u32; + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, Default)] + struct DecorationFlags: u32 { + const NON_READABLE = 0x1; + const NON_WRITABLE = 0x2; + } +} + +impl DecorationFlags { + fn to_storage_access(self) -> crate::StorageAccess { + let mut access = crate::StorageAccess::all(); + if self.contains(DecorationFlags::NON_READABLE) { + access &= !crate::StorageAccess::LOAD; + } + if self.contains(DecorationFlags::NON_WRITABLE) { + access &= !crate::StorageAccess::STORE; + } + access + } +} + +#[derive(Debug, PartialEq)] +enum Majority { + Column, + Row, +} + +#[derive(Debug, Default)] +struct Decoration { + name: Option<String>, + built_in: Option<spirv::Word>, + location: Option<spirv::Word>, + desc_set: Option<spirv::Word>, + desc_index: Option<spirv::Word>, + specialization: Option<spirv::Word>, + storage_buffer: bool, + offset: Option<spirv::Word>, + array_stride: Option<NonZeroU32>, + matrix_stride: Option<NonZeroU32>, + matrix_major: Option<Majority>, + invariant: bool, + interpolation: Option<crate::Interpolation>, + sampling: Option<crate::Sampling>, + flags: DecorationFlags, +} + +impl Decoration { + fn debug_name(&self) -> &str { + match self.name { + Some(ref name) => name.as_str(), + None => "?", + } + } + + fn specialization(&self) -> crate::Override { + self.specialization + .map_or(crate::Override::None, crate::Override::ByNameOrId) + } + + const fn resource_binding(&self) -> Option<crate::ResourceBinding> { + match *self { + Decoration { + desc_set: Some(group), + desc_index: Some(binding), + .. + } => Some(crate::ResourceBinding { group, binding }), + _ => None, + } + } + + fn io_binding(&self) -> Result<crate::Binding, Error> { + match *self { + Decoration { + built_in: Some(built_in), + location: None, + invariant, + .. + } => Ok(crate::Binding::BuiltIn(map_builtin(built_in, invariant)?)), + Decoration { + built_in: None, + location: Some(location), + interpolation, + sampling, + .. + } => Ok(crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source: false, + }), + _ => Err(Error::MissingDecoration(spirv::Decoration::Location)), + } + } +} + +#[derive(Debug)] +struct LookupFunctionType { + parameter_type_ids: Vec<spirv::Word>, + return_type_id: spirv::Word, +} + +struct LookupFunction { + handle: Handle<crate::Function>, + parameters_sampling: Vec<image::SamplingFlags>, +} + +#[derive(Debug)] +struct EntryPoint { + stage: crate::ShaderStage, + name: String, + early_depth_test: Option<crate::EarlyDepthTest>, + workgroup_size: [u32; 3], + variable_ids: Vec<spirv::Word>, +} + +#[derive(Clone, Debug)] +struct LookupType { + handle: Handle<crate::Type>, + base_id: Option<spirv::Word>, +} + +#[derive(Debug)] +struct LookupConstant { + handle: Handle<crate::Constant>, + type_id: spirv::Word, +} + +#[derive(Debug)] +enum Variable { + Global, + Input(crate::FunctionArgument), + Output(crate::FunctionResult), +} + +#[derive(Debug)] +struct LookupVariable { + inner: Variable, + handle: Handle<crate::GlobalVariable>, + type_id: spirv::Word, +} + +/// Information about SPIR-V result ids, stored in `Parser::lookup_expression`. +#[derive(Clone, Debug)] +struct LookupExpression { + /// The `Expression` constructed for this result. + /// + /// Note that, while a SPIR-V result id can be used in any block dominated + /// by its definition, a Naga `Expression` is only in scope for the rest of + /// its subtree. `Parser::get_expr_handle` takes care of spilling the result + /// to a `LocalVariable` which can then be used anywhere. + handle: Handle<crate::Expression>, + + /// The SPIR-V type of this result. + type_id: spirv::Word, + + /// The label id of the block that defines this expression. + /// + /// This is zero for globals, constants, and function parameters, since they + /// originate outside any function's block. + block_id: spirv::Word, +} + +#[derive(Debug)] +struct LookupMember { + type_id: spirv::Word, + // This is true for either matrices, or arrays of matrices (yikes). + row_major: bool, +} + +#[derive(Clone, Debug)] +enum LookupLoadOverride { + /// For arrays of matrices, we track them but not loading yet. + Pending, + /// For matrices, vectors, and scalars, we pre-load the data. + Loaded(Handle<crate::Expression>), +} + +#[derive(PartialEq)] +enum ExtendedClass { + Global(crate::AddressSpace), + Input, + Output, +} + +#[derive(Clone, Debug)] +pub struct Options { + /// The IR coordinate space matches all the APIs except SPIR-V, + /// so by default we flip the Y coordinate of the `BuiltIn::Position`. + /// This flag can be used to avoid this. + pub adjust_coordinate_space: bool, + /// Only allow shaders with the known set of capabilities. + pub strict_capabilities: bool, + pub block_ctx_dump_prefix: Option<PathBuf>, +} + +impl Default for Options { + fn default() -> Self { + Options { + adjust_coordinate_space: true, + strict_capabilities: false, + block_ctx_dump_prefix: None, + } + } +} + +/// An index into the `BlockContext::bodies` table. +type BodyIndex = usize; + +/// An intermediate representation of a Naga [`Statement`]. +/// +/// `Body` and `BodyFragment` values form a tree: the `BodyIndex` fields of the +/// variants are indices of the child `Body` values in [`BlockContext::bodies`]. +/// The `lower` function assembles the final `Statement` tree from this `Body` +/// tree. See [`BlockContext`] for details. +/// +/// [`Statement`]: crate::Statement +#[derive(Debug)] +enum BodyFragment { + BlockId(spirv::Word), + If { + condition: Handle<crate::Expression>, + accept: BodyIndex, + reject: BodyIndex, + }, + Loop { + /// The body of the loop. Its [`Body::parent`] is the block containing + /// this `Loop` fragment. + body: BodyIndex, + + /// The loop's continuing block. This is a grandchild: its + /// [`Body::parent`] is the loop body block, whose index is above. + continuing: BodyIndex, + + /// If the SPIR-V loop's back-edge branch is conditional, this is the + /// expression that must be `false` for the back-edge to be taken, with + /// `true` being for the "loop merge" (which breaks out of the loop). + break_if: Option<Handle<crate::Expression>>, + }, + Switch { + selector: Handle<crate::Expression>, + cases: Vec<(i32, BodyIndex)>, + default: BodyIndex, + }, + Break, + Continue, +} + +/// An intermediate representation of a Naga [`Block`]. +/// +/// This will be assembled into a `Block` once we've added spills for phi nodes +/// and out-of-scope expressions. See [`BlockContext`] for details. +/// +/// [`Block`]: crate::Block +#[derive(Debug)] +struct Body { + /// The index of the direct parent of this body + parent: usize, + data: Vec<BodyFragment>, +} + +impl Body { + /// Creates a new empty `Body` with the specified `parent` + pub const fn with_parent(parent: usize) -> Self { + Body { + parent, + data: Vec::new(), + } + } +} + +#[derive(Debug)] +struct PhiExpression { + /// The local variable used for the phi node + local: Handle<crate::LocalVariable>, + /// List of (expression, block) + expressions: Vec<(spirv::Word, spirv::Word)>, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum MergeBlockInformation { + LoopMerge, + LoopContinue, + SelectionMerge, + SwitchMerge, +} + +/// Fragments of Naga IR, to be assembled into `Statements` once data flow is +/// resolved. +/// +/// We can't build a Naga `Statement` tree directly from SPIR-V blocks for three +/// main reasons: +/// +/// - We parse a function's SPIR-V blocks in the order they appear in the file. +/// Within a function, SPIR-V requires that a block must precede any blocks it +/// structurally dominates, but doesn't say much else about the order in which +/// they must appear. So while we know we'll see control flow header blocks +/// before their child constructs and merge blocks, those children and the +/// merge blocks may appear in any order - perhaps even intermingled with +/// children of other constructs. +/// +/// - A SPIR-V expression can be used in any SPIR-V block dominated by its +/// definition, whereas Naga expressions are scoped to the rest of their +/// subtree. This means that discovering an expression use later in the +/// function retroactively requires us to have spilled that expression into a +/// local variable back before we left its scope. +/// +/// - We translate SPIR-V OpPhi expressions as Naga local variables in which we +/// store the appropriate value before jumping to the OpPhi's block. +/// +/// All these cases require us to go back and amend previously generated Naga IR +/// based on things we discover later. But modifying old blocks in arbitrary +/// spots in a `Statement` tree is awkward. +/// +/// Instead, as we iterate through the function's body, we accumulate +/// control-flow-free fragments of Naga IR in the [`blocks`] table, while +/// building a skeleton of the Naga `Statement` tree in [`bodies`]. We note any +/// spills and temporaries we must introduce in [`phis`]. +/// +/// Finally, once we've processed the entire function, we add temporaries and +/// spills to the fragmentary `Blocks` as directed by `phis`, and assemble them +/// into the final Naga `Statement` tree as directed by `bodies`. +/// +/// [`blocks`]: BlockContext::blocks +/// [`bodies`]: BlockContext::bodies +/// [`phis`]: BlockContext::phis +/// [`lower`]: function::lower +#[derive(Debug)] +struct BlockContext<'function> { + /// Phi nodes encountered when parsing the function, used to generate spills + /// to local variables. + phis: Vec<PhiExpression>, + + /// Fragments of control-flow-free Naga IR. + /// + /// These will be stitched together into a proper [`Statement`] tree according + /// to `bodies`, once parsing is complete. + /// + /// [`Statement`]: crate::Statement + blocks: FastHashMap<spirv::Word, crate::Block>, + + /// Map from each SPIR-V block's label id to the index of the [`Body`] in + /// [`bodies`] the block should append its contents to. + /// + /// Since each statement in a Naga [`Block`] dominates the next, we are sure + /// to encounter their SPIR-V blocks in order. Thus, by having this table + /// map a SPIR-V structured control flow construct's merge block to the same + /// body index as its header block, when we encounter the merge block, we + /// will simply pick up building the [`Body`] where the header left off. + /// + /// A function's first block is special: it is the only block we encounter + /// without having seen its label mentioned in advance. (It's simply the + /// first `OpLabel` after the `OpFunction`.) We thus assume that any block + /// missing an entry here must be the first block, which always has body + /// index zero. + /// + /// [`bodies`]: BlockContext::bodies + /// [`Block`]: crate::Block + body_for_label: FastHashMap<spirv::Word, BodyIndex>, + + /// SPIR-V metadata about merge/continue blocks. + mergers: FastHashMap<spirv::Word, MergeBlockInformation>, + + /// A table of `Body` values, each representing a block in the final IR. + /// + /// The first element is always the function's top-level block. + bodies: Vec<Body>, + + /// Id of the function currently being processed + function_id: spirv::Word, + /// Expression arena of the function currently being processed + expressions: &'function mut Arena<crate::Expression>, + /// Local variables arena of the function currently being processed + local_arena: &'function mut Arena<crate::LocalVariable>, + /// Constants arena of the module being processed + const_arena: &'function mut Arena<crate::Constant>, + const_expressions: &'function mut Arena<crate::Expression>, + /// Type arena of the module being processed + type_arena: &'function UniqueArena<crate::Type>, + /// Global arena of the module being processed + global_arena: &'function Arena<crate::GlobalVariable>, + /// Arguments of the function currently being processed + arguments: &'function [crate::FunctionArgument], + /// Metadata about the usage of function parameters as sampling objects + parameter_sampling: &'function mut [image::SamplingFlags], +} + +enum SignAnchor { + Result, + Operand, +} + +pub struct Frontend<I> { + data: I, + data_offset: usize, + state: ModuleState, + layouter: Layouter, + temp_bytes: Vec<u8>, + ext_glsl_id: Option<spirv::Word>, + future_decor: FastHashMap<spirv::Word, Decoration>, + future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>, + lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>, + handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>, + lookup_type: FastHashMap<spirv::Word, LookupType>, + lookup_void_type: Option<spirv::Word>, + lookup_storage_buffer_types: FastHashMap<Handle<crate::Type>, crate::StorageAccess>, + // Lookup for samplers and sampled images, storing flags on how they are used. + lookup_constant: FastHashMap<spirv::Word, LookupConstant>, + lookup_variable: FastHashMap<spirv::Word, LookupVariable>, + lookup_expression: FastHashMap<spirv::Word, LookupExpression>, + // Load overrides are used to work around row-major matrices + lookup_load_override: FastHashMap<spirv::Word, LookupLoadOverride>, + lookup_sampled_image: FastHashMap<spirv::Word, image::LookupSampledImage>, + lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>, + lookup_function: FastHashMap<spirv::Word, LookupFunction>, + lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>, + //Note: each `OpFunctionCall` gets a single entry here, indexed by the + // dummy `Handle<crate::Function>` of the call site. + deferred_function_calls: Vec<spirv::Word>, + dummy_functions: Arena<crate::Function>, + // Graph of all function calls through the module. + // It's used to sort the functions (as nodes) topologically, + // so that in the IR any called function is already known. + function_call_graph: GraphMap<spirv::Word, (), petgraph::Directed>, + options: Options, + + /// Maps for a switch from a case target to the respective body and associated literals that + /// use that target block id. + /// + /// Used to preserve allocations between instruction parsing. + switch_cases: FastIndexMap<spirv::Word, (BodyIndex, Vec<i32>)>, + + /// Tracks access to gl_PerVertex's builtins, it is used to cull unused builtins since initializing those can + /// affect performance and the mere presence of some of these builtins might cause backends to error since they + /// might be unsupported. + /// + /// The problematic builtins are: PointSize, ClipDistance and CullDistance. + /// + /// glslang declares those by default even though they are never written to + /// (see <https://github.com/KhronosGroup/glslang/issues/1868>) + gl_per_vertex_builtin_access: FastHashSet<crate::BuiltIn>, +} + +impl<I: Iterator<Item = u32>> Frontend<I> { + pub fn new(data: I, options: &Options) -> Self { + Frontend { + data, + data_offset: 0, + state: ModuleState::Empty, + layouter: Layouter::default(), + temp_bytes: Vec::new(), + ext_glsl_id: None, + future_decor: FastHashMap::default(), + future_member_decor: FastHashMap::default(), + handle_sampling: FastHashMap::default(), + lookup_member: FastHashMap::default(), + lookup_type: FastHashMap::default(), + lookup_void_type: None, + lookup_storage_buffer_types: FastHashMap::default(), + lookup_constant: FastHashMap::default(), + lookup_variable: FastHashMap::default(), + lookup_expression: FastHashMap::default(), + lookup_load_override: FastHashMap::default(), + lookup_sampled_image: FastHashMap::default(), + lookup_function_type: FastHashMap::default(), + lookup_function: FastHashMap::default(), + lookup_entry_point: FastHashMap::default(), + deferred_function_calls: Vec::default(), + dummy_functions: Arena::new(), + function_call_graph: GraphMap::new(), + options: options.clone(), + switch_cases: FastIndexMap::default(), + gl_per_vertex_builtin_access: FastHashSet::default(), + } + } + + fn span_from(&self, from: usize) -> crate::Span { + crate::Span::from(from..self.data_offset) + } + + fn span_from_with_op(&self, from: usize) -> crate::Span { + crate::Span::from((from - 4)..self.data_offset) + } + + fn next(&mut self) -> Result<u32, Error> { + if let Some(res) = self.data.next() { + self.data_offset += 4; + Ok(res) + } else { + Err(Error::IncompleteData) + } + } + + fn next_inst(&mut self) -> Result<Instruction, Error> { + let word = self.next()?; + let (wc, opcode) = ((word >> 16) as u16, (word & 0xffff) as u16); + if wc == 0 { + return Err(Error::InvalidWordCount); + } + let op = spirv::Op::from_u32(opcode as u32).ok_or(Error::UnknownInstruction(opcode))?; + + Ok(Instruction { op, wc }) + } + + fn next_string(&mut self, mut count: u16) -> Result<(String, u16), Error> { + self.temp_bytes.clear(); + loop { + if count == 0 { + return Err(Error::BadString); + } + count -= 1; + let chars = self.next()?.to_le_bytes(); + let pos = chars.iter().position(|&c| c == 0).unwrap_or(4); + self.temp_bytes.extend_from_slice(&chars[..pos]); + if pos < 4 { + break; + } + } + std::str::from_utf8(&self.temp_bytes) + .map(|s| (s.to_owned(), count)) + .map_err(|_| Error::BadString) + } + + fn next_decoration( + &mut self, + inst: Instruction, + base_words: u16, + dec: &mut Decoration, + ) -> Result<(), Error> { + let raw = self.next()?; + let dec_typed = spirv::Decoration::from_u32(raw).ok_or(Error::InvalidDecoration(raw))?; + log::trace!("\t\t{}: {:?}", dec.debug_name(), dec_typed); + match dec_typed { + spirv::Decoration::BuiltIn => { + inst.expect(base_words + 2)?; + dec.built_in = Some(self.next()?); + } + spirv::Decoration::Location => { + inst.expect(base_words + 2)?; + dec.location = Some(self.next()?); + } + spirv::Decoration::DescriptorSet => { + inst.expect(base_words + 2)?; + dec.desc_set = Some(self.next()?); + } + spirv::Decoration::Binding => { + inst.expect(base_words + 2)?; + dec.desc_index = Some(self.next()?); + } + spirv::Decoration::BufferBlock => { + dec.storage_buffer = true; + } + spirv::Decoration::Offset => { + inst.expect(base_words + 2)?; + dec.offset = Some(self.next()?); + } + spirv::Decoration::ArrayStride => { + inst.expect(base_words + 2)?; + dec.array_stride = NonZeroU32::new(self.next()?); + } + spirv::Decoration::MatrixStride => { + inst.expect(base_words + 2)?; + dec.matrix_stride = NonZeroU32::new(self.next()?); + } + spirv::Decoration::Invariant => { + dec.invariant = true; + } + spirv::Decoration::NoPerspective => { + dec.interpolation = Some(crate::Interpolation::Linear); + } + spirv::Decoration::Flat => { + dec.interpolation = Some(crate::Interpolation::Flat); + } + spirv::Decoration::Centroid => { + dec.sampling = Some(crate::Sampling::Centroid); + } + spirv::Decoration::Sample => { + dec.sampling = Some(crate::Sampling::Sample); + } + spirv::Decoration::NonReadable => { + dec.flags |= DecorationFlags::NON_READABLE; + } + spirv::Decoration::NonWritable => { + dec.flags |= DecorationFlags::NON_WRITABLE; + } + spirv::Decoration::ColMajor => { + dec.matrix_major = Some(Majority::Column); + } + spirv::Decoration::RowMajor => { + dec.matrix_major = Some(Majority::Row); + } + spirv::Decoration::SpecId => { + dec.specialization = Some(self.next()?); + } + other => { + log::warn!("Unknown decoration {:?}", other); + for _ in base_words + 1..inst.wc { + let _var = self.next()?; + } + } + } + Ok(()) + } + + /// Return the Naga `Expression` for a given SPIR-V result `id`. + /// + /// `lookup` must be the `LookupExpression` for `id`. + /// + /// SPIR-V result ids can be used by any block dominated by the id's + /// definition, but Naga `Expressions` are only in scope for the remainder + /// of their `Statement` subtree. This means that the `Expression` generated + /// for `id` may no longer be in scope. In such cases, this function takes + /// care of spilling the value of `id` to a `LocalVariable` which can then + /// be used anywhere. The SPIR-V domination rule ensures that the + /// `LocalVariable` has been initialized before it is used. + /// + /// The `body_idx` argument should be the index of the `Body` that hopes to + /// use `id`'s `Expression`. + fn get_expr_handle( + &self, + id: spirv::Word, + lookup: &LookupExpression, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + body_idx: BodyIndex, + ) -> Handle<crate::Expression> { + // What `Body` was `id` defined in? + let expr_body_idx = ctx + .body_for_label + .get(&lookup.block_id) + .copied() + .unwrap_or(0); + + // Don't need to do a load/store if the expression is in the main body + // or if the expression is in the same body as where the query was + // requested. The body_idx might actually not be the final one if a loop + // or conditional occurs but in those cases we know that the new body + // will be a subscope of the body that was passed so we can still reuse + // the handle and not issue a load/store. + if is_parent(body_idx, expr_body_idx, ctx) { + lookup.handle + } else { + // Add a temporary variable of the same type which will be used to + // store the original expression and used in the current block + let ty = self.lookup_type[&lookup.type_id].handle; + let local = ctx.local_arena.append( + crate::LocalVariable { + name: None, + ty, + init: None, + }, + crate::Span::default(), + ); + + block.extend(emitter.finish(ctx.expressions)); + let pointer = ctx.expressions.append( + crate::Expression::LocalVariable(local), + crate::Span::default(), + ); + emitter.start(ctx.expressions); + let expr = ctx + .expressions + .append(crate::Expression::Load { pointer }, crate::Span::default()); + + // Add a slightly odd entry to the phi table, so that while `id`'s + // `Expression` is still in scope, the usual phi processing will + // spill its value to `local`, where we can find it later. + // + // This pretends that the block in which `id` is defined is the + // predecessor of some other block with a phi in it that cites id as + // one of its sources, and uses `local` as its variable. There is no + // such phi, but nobody needs to know that. + ctx.phis.push(PhiExpression { + local, + expressions: vec![(id, lookup.block_id)], + }); + + expr + } + } + + fn parse_expr_unary_op( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::UnaryOperator, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p_id = self.next()?; + + let p_lexp = self.lookup_expression.lookup(p_id)?; + let handle = self.get_expr_handle(p_id, p_lexp, ctx, emitter, block, body_idx); + + let expr = crate::Expression::Unary { op, expr: handle }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + fn parse_expr_binary_op( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::BinaryOperator, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx); + + let expr = crate::Expression::Binary { op, left, right }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + /// A more complicated version of the unary op, + /// where we force the operand to have the same type as the result. + fn parse_expr_unary_op_sign_adjusted( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::UnaryOperator, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + + let result_lookup_ty = self.lookup_type.lookup(result_type_id)?; + let kind = ctx.type_arena[result_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let expr = crate::Expression::Unary { + op, + expr: if p1_lexp.type_id == result_type_id { + left + } else { + ctx.expressions.append( + crate::Expression::As { + expr: left, + kind, + convert: None, + }, + span, + ) + }, + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + /// A more complicated version of the binary op, + /// where we force the operand to have the same type as the result. + /// This is mostly needed for "i++" and "i--" coming from GLSL. + #[allow(clippy::too_many_arguments)] + fn parse_expr_binary_op_sign_adjusted( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::BinaryOperator, + // For arithmetic operations, we need the sign of operands to match the result. + // For boolean operations, however, the operands need to match the signs, but + // result is always different - a boolean. + anchor: SignAnchor, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx); + + let expected_type_id = match anchor { + SignAnchor::Result => result_type_id, + SignAnchor::Operand => p1_lexp.type_id, + }; + let expected_lookup_ty = self.lookup_type.lookup(expected_type_id)?; + let kind = ctx.type_arena[expected_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let expr = crate::Expression::Binary { + op, + left: if p1_lexp.type_id == expected_type_id { + left + } else { + ctx.expressions.append( + crate::Expression::As { + expr: left, + kind, + convert: None, + }, + span, + ) + }, + right: if p2_lexp.type_id == expected_type_id { + right + } else { + ctx.expressions.append( + crate::Expression::As { + expr: right, + kind, + convert: None, + }, + span, + ) + }, + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + /// A version of the binary op where one or both of the arguments might need to be casted to a + /// specific integer kind (unsigned or signed), used for operations like OpINotEqual or + /// OpUGreaterThan. + #[allow(clippy::too_many_arguments)] + fn parse_expr_int_comparison( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::BinaryOperator, + kind: crate::ScalarKind, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + let p1_lookup_ty = self.lookup_type.lookup(p1_lexp.type_id)?; + let p1_kind = ctx.type_arena[p1_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx); + let p2_lookup_ty = self.lookup_type.lookup(p2_lexp.type_id)?; + let p2_kind = ctx.type_arena[p2_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let expr = crate::Expression::Binary { + op, + left: if p1_kind == kind { + left + } else { + ctx.expressions.append( + crate::Expression::As { + expr: left, + kind, + convert: None, + }, + span, + ) + }, + right: if p2_kind == kind { + right + } else { + ctx.expressions.append( + crate::Expression::As { + expr: right, + kind, + convert: None, + }, + span, + ) + }, + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + fn parse_expr_shift_op( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::BinaryOperator, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let p2_handle = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx); + // convert the shift to Uint + let right = ctx.expressions.append( + crate::Expression::As { + expr: p2_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ); + + let expr = crate::Expression::Binary { op, left, right }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + fn parse_expr_derivative( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + (axis, ctrl): (crate::DerivativeAxis, crate::DerivativeControl), + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let arg_id = self.next()?; + + let arg_lexp = self.lookup_expression.lookup(arg_id)?; + let arg_handle = self.get_expr_handle(arg_id, arg_lexp, ctx, emitter, block, body_idx); + + let expr = crate::Expression::Derivative { + axis, + ctrl, + expr: arg_handle, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + fn insert_composite( + &self, + root_expr: Handle<crate::Expression>, + root_type_id: spirv::Word, + object_expr: Handle<crate::Expression>, + selections: &[spirv::Word], + type_arena: &UniqueArena<crate::Type>, + expressions: &mut Arena<crate::Expression>, + span: crate::Span, + ) -> Result<Handle<crate::Expression>, Error> { + let selection = match selections.first() { + Some(&index) => index, + None => return Ok(object_expr), + }; + let root_span = expressions.get_span(root_expr); + let root_lookup = self.lookup_type.lookup(root_type_id)?; + + let (count, child_type_id) = match type_arena[root_lookup.handle].inner { + crate::TypeInner::Struct { ref members, .. } => { + let child_member = self + .lookup_member + .get(&(root_lookup.handle, selection)) + .ok_or(Error::InvalidAccessType(root_type_id))?; + (members.len(), child_member.type_id) + } + crate::TypeInner::Array { size, .. } => { + let size = match size { + crate::ArraySize::Constant(size) => size.get(), + // A runtime sized array is not a composite type + crate::ArraySize::Dynamic => { + return Err(Error::InvalidAccessType(root_type_id)) + } + }; + + let child_type_id = root_lookup + .base_id + .ok_or(Error::InvalidAccessType(root_type_id))?; + + (size as usize, child_type_id) + } + crate::TypeInner::Vector { size, .. } + | crate::TypeInner::Matrix { columns: size, .. } => { + let child_type_id = root_lookup + .base_id + .ok_or(Error::InvalidAccessType(root_type_id))?; + (size as usize, child_type_id) + } + _ => return Err(Error::InvalidAccessType(root_type_id)), + }; + + let mut components = Vec::with_capacity(count); + for index in 0..count as u32 { + let expr = expressions.append( + crate::Expression::AccessIndex { + base: root_expr, + index, + }, + if index == selection { span } else { root_span }, + ); + components.push(expr); + } + components[selection as usize] = self.insert_composite( + components[selection as usize], + child_type_id, + object_expr, + &selections[1..], + type_arena, + expressions, + span, + )?; + + Ok(expressions.append( + crate::Expression::Compose { + ty: root_lookup.handle, + components, + }, + span, + )) + } + + /// Add the next SPIR-V block's contents to `block_ctx`. + /// + /// Except for the function's entry block, `block_id` should be the label of + /// a block we've seen mentioned before, with an entry in + /// `block_ctx.body_for_label` to tell us which `Body` it contributes to. + fn next_block(&mut self, block_id: spirv::Word, ctx: &mut BlockContext) -> Result<(), Error> { + // Extend `body` with the correct form for a branch to `target`. + fn merger(body: &mut Body, target: &MergeBlockInformation) { + body.data.push(match *target { + MergeBlockInformation::LoopContinue => BodyFragment::Continue, + MergeBlockInformation::LoopMerge | MergeBlockInformation::SwitchMerge => { + BodyFragment::Break + } + + // Finishing a selection merge means just falling off the end of + // the `accept` or `reject` block of the `If` statement. + MergeBlockInformation::SelectionMerge => return, + }) + } + + let mut emitter = crate::proc::Emitter::default(); + emitter.start(ctx.expressions); + + // Find the `Body` to which this block contributes. + // + // If this is some SPIR-V structured control flow construct's merge + // block, then `body_idx` will refer to the same `Body` as the header, + // so that we simply pick up accumulating the `Body` where the header + // left off. Each of the statements in a block dominates the next, so + // we're sure to encounter their SPIR-V blocks in order, ensuring that + // the `Body` will be assembled in the proper order. + // + // Note that, unlike every other kind of SPIR-V block, we don't know the + // function's first block's label in advance. Thus, we assume that if + // this block has no entry in `ctx.body_for_label`, it must be the + // function's first block. This always has body index zero. + let mut body_idx = *ctx.body_for_label.entry(block_id).or_default(); + + // The Naga IR block this call builds. This will end up as + // `ctx.blocks[&block_id]`, and `ctx.bodies[body_idx]` will refer to it + // via a `BodyFragment::BlockId`. + let mut block = crate::Block::new(); + + // Stores the merge block as defined by a `OpSelectionMerge` otherwise is `None` + // + // This is used in `OpSwitch` to promote the `MergeBlockInformation` from + // `SelectionMerge` to `SwitchMerge` to allow `Break`s this isn't desirable for + // `LoopMerge`s because otherwise `Continue`s wouldn't be allowed + let mut selection_merge_block = None; + + macro_rules! get_expr_handle { + ($id:expr, $lexp:expr) => { + self.get_expr_handle($id, $lexp, ctx, &mut emitter, &mut block, body_idx) + }; + } + macro_rules! parse_expr_op { + ($op:expr, BINARY) => { + self.parse_expr_binary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op) + }; + + ($op:expr, SHIFT) => { + self.parse_expr_shift_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op) + }; + ($op:expr, UNARY) => { + self.parse_expr_unary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op) + }; + ($axis:expr, $ctrl:expr, DERIVATIVE) => { + self.parse_expr_derivative( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + ($axis, $ctrl), + ) + }; + } + + let terminator = loop { + use spirv::Op; + let start = self.data_offset; + let inst = self.next_inst()?; + let span = crate::Span::from(start..(start + 4 * (inst.wc as usize))); + log::debug!("\t\t{:?} [{}]", inst.op, inst.wc); + + match inst.op { + Op::Line => { + inst.expect(4)?; + let _file_id = self.next()?; + let _row_id = self.next()?; + let _col_id = self.next()?; + } + Op::NoLine => inst.expect(1)?, + Op::Undef => { + inst.expect(3)?; + let type_id = self.next()?; + let id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + self.lookup_expression.insert( + id, + LookupExpression { + handle: ctx + .expressions + .append(crate::Expression::ZeroValue(ty), span), + type_id, + block_id, + }, + ); + } + Op::Variable => { + inst.expect_at_least(4)?; + block.extend(emitter.finish(ctx.expressions)); + + let result_type_id = self.next()?; + let result_id = self.next()?; + let _storage_class = self.next()?; + let init = if inst.wc > 4 { + inst.expect(5)?; + let init_id = self.next()?; + let lconst = self.lookup_constant.lookup(init_id)?; + Some( + ctx.expressions + .append(crate::Expression::Constant(lconst.handle), span), + ) + } else { + None + }; + + let name = self + .future_decor + .remove(&result_id) + .and_then(|decor| decor.name); + if let Some(ref name) = name { + log::debug!("\t\t\tid={} name={}", result_id, name); + } + let lookup_ty = self.lookup_type.lookup(result_type_id)?; + let var_handle = ctx.local_arena.append( + crate::LocalVariable { + name, + ty: match ctx.type_arena[lookup_ty.handle].inner { + crate::TypeInner::Pointer { base, .. } => base, + _ => lookup_ty.handle, + }, + init, + }, + span, + ); + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx + .expressions + .append(crate::Expression::LocalVariable(var_handle), span), + type_id: result_type_id, + block_id, + }, + ); + emitter.start(ctx.expressions); + } + Op::Phi => { + inst.expect_at_least(3)?; + block.extend(emitter.finish(ctx.expressions)); + + let result_type_id = self.next()?; + let result_id = self.next()?; + + let name = format!("phi_{result_id}"); + let local = ctx.local_arena.append( + crate::LocalVariable { + name: Some(name), + ty: self.lookup_type.lookup(result_type_id)?.handle, + init: None, + }, + self.span_from(start), + ); + let pointer = ctx + .expressions + .append(crate::Expression::LocalVariable(local), span); + + let in_count = (inst.wc - 3) / 2; + let mut phi = PhiExpression { + local, + expressions: Vec::with_capacity(in_count as usize), + }; + for _ in 0..in_count { + let expr = self.next()?; + let block = self.next()?; + phi.expressions.push((expr, block)); + } + + ctx.phis.push(phi); + emitter.start(ctx.expressions); + + // Associate the lookup with an actual value, which is emitted + // into the current block. + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx + .expressions + .append(crate::Expression::Load { pointer }, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::AccessChain | Op::InBoundsAccessChain => { + struct AccessExpression { + base_handle: Handle<crate::Expression>, + type_id: spirv::Word, + load_override: Option<LookupLoadOverride>, + } + + inst.expect_at_least(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + log::trace!("\t\t\tlooking up expr {:?}", base_id); + + let mut acex = { + let lexp = self.lookup_expression.lookup(base_id)?; + let lty = self.lookup_type.lookup(lexp.type_id)?; + + // HACK `OpAccessChain` and `OpInBoundsAccessChain` + // require for the result type to be a pointer, but if + // we're given a pointer to an image / sampler, it will + // be *already* dereferenced, since we do that early + // during `parse_type_pointer()`. + // + // This can happen only through `BindingArray`, since + // that's the only case where one can obtain a pointer + // to an image / sampler, and so let's match on that: + let dereference = match ctx.type_arena[lty.handle].inner { + crate::TypeInner::BindingArray { .. } => false, + _ => true, + }; + + let type_id = if dereference { + lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))? + } else { + lexp.type_id + }; + + AccessExpression { + base_handle: get_expr_handle!(base_id, lexp), + type_id, + load_override: self.lookup_load_override.get(&base_id).cloned(), + } + }; + + for _ in 4..inst.wc { + let access_id = self.next()?; + log::trace!("\t\t\tlooking up index expr {:?}", access_id); + let index_expr = self.lookup_expression.lookup(access_id)?.clone(); + let index_expr_handle = get_expr_handle!(access_id, &index_expr); + let index_expr_data = &ctx.expressions[index_expr.handle]; + let index_maybe = match *index_expr_data { + crate::Expression::Constant(const_handle) => Some( + ctx.gctx() + .eval_expr_to_u32(ctx.const_arena[const_handle].init) + .map_err(|_| { + Error::InvalidAccess(crate::Expression::Constant( + const_handle, + )) + })?, + ), + _ => None, + }; + + log::trace!("\t\t\tlooking up type {:?}", acex.type_id); + let type_lookup = self.lookup_type.lookup(acex.type_id)?; + let ty = &ctx.type_arena[type_lookup.handle]; + acex = match ty.inner { + // can only index a struct with a constant + crate::TypeInner::Struct { ref members, .. } => { + let index = index_maybe + .ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?; + + let lookup_member = self + .lookup_member + .get(&(type_lookup.handle, index)) + .ok_or(Error::InvalidAccessType(acex.type_id))?; + let base_handle = ctx.expressions.append( + crate::Expression::AccessIndex { + base: acex.base_handle, + index, + }, + span, + ); + + if ty.name.as_deref() == Some("gl_PerVertex") { + if let Some(crate::Binding::BuiltIn(built_in)) = + members[index as usize].binding + { + self.gl_per_vertex_builtin_access.insert(built_in); + } + } + + AccessExpression { + base_handle, + type_id: lookup_member.type_id, + load_override: if lookup_member.row_major { + debug_assert!(acex.load_override.is_none()); + let sub_type_lookup = + self.lookup_type.lookup(lookup_member.type_id)?; + Some(match ctx.type_arena[sub_type_lookup.handle].inner { + // load it transposed, to match column major expectations + crate::TypeInner::Matrix { .. } => { + let loaded = ctx.expressions.append( + crate::Expression::Load { + pointer: base_handle, + }, + span, + ); + let transposed = ctx.expressions.append( + crate::Expression::Math { + fun: crate::MathFunction::Transpose, + arg: loaded, + arg1: None, + arg2: None, + arg3: None, + }, + span, + ); + LookupLoadOverride::Loaded(transposed) + } + _ => LookupLoadOverride::Pending, + }) + } else { + None + }, + } + } + crate::TypeInner::Matrix { .. } => { + let load_override = match acex.load_override { + // We are indexing inside a row-major matrix + Some(LookupLoadOverride::Loaded(load_expr)) => { + let index = index_maybe.ok_or_else(|| { + Error::InvalidAccess(index_expr_data.clone()) + })?; + let sub_handle = ctx.expressions.append( + crate::Expression::AccessIndex { + base: load_expr, + index, + }, + span, + ); + Some(LookupLoadOverride::Loaded(sub_handle)) + } + _ => None, + }; + let sub_expr = match index_maybe { + Some(index) => crate::Expression::AccessIndex { + base: acex.base_handle, + index, + }, + None => crate::Expression::Access { + base: acex.base_handle, + index: index_expr_handle, + }, + }; + AccessExpression { + base_handle: ctx.expressions.append(sub_expr, span), + type_id: type_lookup + .base_id + .ok_or(Error::InvalidAccessType(acex.type_id))?, + load_override, + } + } + // This must be a vector or an array. + _ => { + let base_handle = ctx.expressions.append( + crate::Expression::Access { + base: acex.base_handle, + index: index_expr_handle, + }, + span, + ); + let load_override = match acex.load_override { + // If there is a load override in place, then we always end up + // with a side-loaded value here. + Some(lookup_load_override) => { + let sub_expr = match lookup_load_override { + // We must be indexing into the array of row-major matrices. + // Let's load the result of indexing and transpose it. + LookupLoadOverride::Pending => { + let loaded = ctx.expressions.append( + crate::Expression::Load { + pointer: base_handle, + }, + span, + ); + ctx.expressions.append( + crate::Expression::Math { + fun: crate::MathFunction::Transpose, + arg: loaded, + arg1: None, + arg2: None, + arg3: None, + }, + span, + ) + } + // We are indexing inside a row-major matrix. + LookupLoadOverride::Loaded(load_expr) => { + ctx.expressions.append( + crate::Expression::Access { + base: load_expr, + index: index_expr_handle, + }, + span, + ) + } + }; + Some(LookupLoadOverride::Loaded(sub_expr)) + } + None => None, + }; + AccessExpression { + base_handle, + type_id: type_lookup + .base_id + .ok_or(Error::InvalidAccessType(acex.type_id))?, + load_override, + } + } + }; + } + + if let Some(load_expr) = acex.load_override { + self.lookup_load_override.insert(result_id, load_expr); + } + let lookup_expression = LookupExpression { + handle: acex.base_handle, + type_id: result_type_id, + block_id, + }; + self.lookup_expression.insert(result_id, lookup_expression); + } + Op::VectorExtractDynamic => { + inst.expect(5)?; + + let result_type_id = self.next()?; + let id = self.next()?; + let composite_id = self.next()?; + let index_id = self.next()?; + + let root_lexp = self.lookup_expression.lookup(composite_id)?; + let root_handle = get_expr_handle!(composite_id, root_lexp); + let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?; + let index_lexp = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lexp); + let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle; + + let num_components = match ctx.type_arena[root_type_lookup.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u32, + _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)), + }; + + let mut make_index = |ctx: &mut BlockContext, index: u32| { + make_index_literal( + ctx, + index, + &mut block, + &mut emitter, + index_type, + index_lexp.type_id, + span, + ) + }; + + let index_expr = make_index(ctx, 0)?; + let mut handle = ctx.expressions.append( + crate::Expression::Access { + base: root_handle, + index: index_expr, + }, + span, + ); + for index in 1..num_components { + let index_expr = make_index(ctx, index)?; + let access_expr = ctx.expressions.append( + crate::Expression::Access { + base: root_handle, + index: index_expr, + }, + span, + ); + let cond = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Equal, + left: index_expr, + right: index_handle, + }, + span, + ); + handle = ctx.expressions.append( + crate::Expression::Select { + condition: cond, + accept: access_expr, + reject: handle, + }, + span, + ); + } + + self.lookup_expression.insert( + id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::VectorInsertDynamic => { + inst.expect(6)?; + + let result_type_id = self.next()?; + let id = self.next()?; + let composite_id = self.next()?; + let object_id = self.next()?; + let index_id = self.next()?; + + let object_lexp = self.lookup_expression.lookup(object_id)?; + let object_handle = get_expr_handle!(object_id, object_lexp); + let root_lexp = self.lookup_expression.lookup(composite_id)?; + let root_handle = get_expr_handle!(composite_id, root_lexp); + let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?; + let index_lexp = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lexp); + let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle; + + let num_components = match ctx.type_arena[root_type_lookup.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u32, + _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)), + }; + + let mut components = Vec::with_capacity(num_components as usize); + for index in 0..num_components { + let index_expr = make_index_literal( + ctx, + index, + &mut block, + &mut emitter, + index_type, + index_lexp.type_id, + span, + )?; + let access_expr = ctx.expressions.append( + crate::Expression::Access { + base: root_handle, + index: index_expr, + }, + span, + ); + let cond = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Equal, + left: index_expr, + right: index_handle, + }, + span, + ); + let handle = ctx.expressions.append( + crate::Expression::Select { + condition: cond, + accept: object_handle, + reject: access_expr, + }, + span, + ); + components.push(handle); + } + let handle = ctx.expressions.append( + crate::Expression::Compose { + ty: root_type_lookup.handle, + components, + }, + span, + ); + + self.lookup_expression.insert( + id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::CompositeExtract => { + inst.expect_at_least(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + log::trace!("\t\t\tlooking up expr {:?}", base_id); + let mut lexp = self.lookup_expression.lookup(base_id)?.clone(); + lexp.handle = get_expr_handle!(base_id, &lexp); + for _ in 4..inst.wc { + let index = self.next()?; + log::trace!("\t\t\tlooking up type {:?}", lexp.type_id); + let type_lookup = self.lookup_type.lookup(lexp.type_id)?; + let type_id = match ctx.type_arena[type_lookup.handle].inner { + crate::TypeInner::Struct { .. } => { + self.lookup_member + .get(&(type_lookup.handle, index)) + .ok_or(Error::InvalidAccessType(lexp.type_id))? + .type_id + } + crate::TypeInner::Array { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } => type_lookup + .base_id + .ok_or(Error::InvalidAccessType(lexp.type_id))?, + ref other => { + log::warn!("composite type {:?}", other); + return Err(Error::UnsupportedType(type_lookup.handle)); + } + }; + lexp = LookupExpression { + handle: ctx.expressions.append( + crate::Expression::AccessIndex { + base: lexp.handle, + index, + }, + span, + ), + type_id, + block_id, + }; + } + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: lexp.handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::CompositeInsert => { + inst.expect_at_least(5)?; + + let result_type_id = self.next()?; + let id = self.next()?; + let object_id = self.next()?; + let composite_id = self.next()?; + let mut selections = Vec::with_capacity(inst.wc as usize - 5); + for _ in 5..inst.wc { + selections.push(self.next()?); + } + + let object_lexp = self.lookup_expression.lookup(object_id)?.clone(); + let object_handle = get_expr_handle!(object_id, &object_lexp); + let root_lexp = self.lookup_expression.lookup(composite_id)?.clone(); + let root_handle = get_expr_handle!(composite_id, &root_lexp); + let handle = self.insert_composite( + root_handle, + result_type_id, + object_handle, + &selections, + ctx.type_arena, + ctx.expressions, + span, + )?; + + self.lookup_expression.insert( + id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::CompositeConstruct => { + inst.expect_at_least(3)?; + + let result_type_id = self.next()?; + let id = self.next()?; + let mut components = Vec::with_capacity(inst.wc as usize - 2); + for _ in 3..inst.wc { + let comp_id = self.next()?; + log::trace!("\t\t\tlooking up expr {:?}", comp_id); + let lexp = self.lookup_expression.lookup(comp_id)?; + let handle = get_expr_handle!(comp_id, lexp); + components.push(handle); + } + let ty = self.lookup_type.lookup(result_type_id)?.handle; + let first = components[0]; + let expr = match ctx.type_arena[ty].inner { + // this is an optimization to detect the splat + crate::TypeInner::Vector { size, .. } + if components.len() == size as usize + && components[1..].iter().all(|&c| c == first) => + { + crate::Expression::Splat { size, value: first } + } + _ => crate::Expression::Compose { ty, components }, + }; + self.lookup_expression.insert( + id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::Load => { + inst.expect_at_least(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let pointer_id = self.next()?; + if inst.wc != 4 { + inst.expect(5)?; + let _memory_access = self.next()?; + } + + let base_lexp = self.lookup_expression.lookup(pointer_id)?; + let base_handle = get_expr_handle!(pointer_id, base_lexp); + let type_lookup = self.lookup_type.lookup(base_lexp.type_id)?; + let handle = match ctx.type_arena[type_lookup.handle].inner { + crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => { + base_handle + } + _ => match self.lookup_load_override.get(&pointer_id) { + Some(&LookupLoadOverride::Loaded(handle)) => handle, + //Note: we aren't handling `LookupLoadOverride::Pending` properly here + _ => ctx.expressions.append( + crate::Expression::Load { + pointer: base_handle, + }, + span, + ), + }, + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::Store => { + inst.expect_at_least(3)?; + + let pointer_id = self.next()?; + let value_id = self.next()?; + if inst.wc != 3 { + inst.expect(4)?; + let _memory_access = self.next()?; + } + let base_expr = self.lookup_expression.lookup(pointer_id)?; + let base_handle = get_expr_handle!(pointer_id, base_expr); + let value_expr = self.lookup_expression.lookup(value_id)?; + let value_handle = get_expr_handle!(value_id, value_expr); + + block.extend(emitter.finish(ctx.expressions)); + block.push( + crate::Statement::Store { + pointer: base_handle, + value: value_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } + // Arithmetic Instructions +, -, *, /, % + Op::SNegate | Op::FNegate => { + inst.expect(4)?; + self.parse_expr_unary_op_sign_adjusted( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + crate::UnaryOperator::Negate, + )?; + } + Op::IAdd + | Op::ISub + | Op::IMul + | Op::BitwiseOr + | Op::BitwiseXor + | Op::BitwiseAnd + | Op::SDiv + | Op::SRem => { + inst.expect(5)?; + let operator = map_binary_operator(inst.op)?; + self.parse_expr_binary_op_sign_adjusted( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + operator, + SignAnchor::Result, + )?; + } + Op::IEqual | Op::INotEqual => { + inst.expect(5)?; + let operator = map_binary_operator(inst.op)?; + self.parse_expr_binary_op_sign_adjusted( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + operator, + SignAnchor::Operand, + )?; + } + Op::FAdd => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Add, BINARY)?; + } + Op::FSub => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Subtract, BINARY)?; + } + Op::FMul => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?; + } + Op::UDiv | Op::FDiv => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Divide, BINARY)?; + } + Op::UMod | Op::FRem => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Modulo, BINARY)?; + } + Op::SMod => { + inst.expect(5)?; + + // x - y * int(floor(float(x) / float(y))) + + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle( + p1_id, + p1_lexp, + ctx, + &mut emitter, + &mut block, + body_idx, + ); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle( + p2_id, + p2_lexp, + ctx, + &mut emitter, + &mut block, + body_idx, + ); + + let result_ty = self.lookup_type.lookup(result_type_id)?; + let inner = &ctx.type_arena[result_ty.handle].inner; + let kind = inner.scalar_kind().unwrap(); + let size = inner.size(ctx.gctx()) as u8; + + let left_cast = ctx.expressions.append( + crate::Expression::As { + expr: left, + kind: crate::ScalarKind::Float, + convert: Some(size), + }, + span, + ); + let right_cast = ctx.expressions.append( + crate::Expression::As { + expr: right, + kind: crate::ScalarKind::Float, + convert: Some(size), + }, + span, + ); + let div = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left: left_cast, + right: right_cast, + }, + span, + ); + let floor = ctx.expressions.append( + crate::Expression::Math { + fun: crate::MathFunction::Floor, + arg: div, + arg1: None, + arg2: None, + arg3: None, + }, + span, + ); + let cast = ctx.expressions.append( + crate::Expression::As { + expr: floor, + kind, + convert: Some(size), + }, + span, + ); + let mult = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Multiply, + left: cast, + right, + }, + span, + ); + let sub = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Subtract, + left, + right: mult, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: sub, + type_id: result_type_id, + block_id, + }, + ); + } + Op::FMod => { + inst.expect(5)?; + + // x - y * floor(x / y) + + let start = self.data_offset; + let span = self.span_from_with_op(start); + + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle( + p1_id, + p1_lexp, + ctx, + &mut emitter, + &mut block, + body_idx, + ); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle( + p2_id, + p2_lexp, + ctx, + &mut emitter, + &mut block, + body_idx, + ); + + let div = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left, + right, + }, + span, + ); + let floor = ctx.expressions.append( + crate::Expression::Math { + fun: crate::MathFunction::Floor, + arg: div, + arg1: None, + arg2: None, + arg3: None, + }, + span, + ); + let mult = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Multiply, + left: floor, + right, + }, + span, + ); + let sub = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Subtract, + left, + right: mult, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: sub, + type_id: result_type_id, + block_id, + }, + ); + } + Op::VectorTimesScalar + | Op::VectorTimesMatrix + | Op::MatrixTimesScalar + | Op::MatrixTimesVector + | Op::MatrixTimesMatrix => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?; + } + Op::Transpose => { + inst.expect(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let matrix_id = self.next()?; + let matrix_lexp = self.lookup_expression.lookup(matrix_id)?; + let matrix_handle = get_expr_handle!(matrix_id, matrix_lexp); + let expr = crate::Expression::Math { + fun: crate::MathFunction::Transpose, + arg: matrix_handle, + arg1: None, + arg2: None, + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::Dot => { + inst.expect(5)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let left_id = self.next()?; + let right_id = self.next()?; + let left_lexp = self.lookup_expression.lookup(left_id)?; + let left_handle = get_expr_handle!(left_id, left_lexp); + let right_lexp = self.lookup_expression.lookup(right_id)?; + let right_handle = get_expr_handle!(right_id, right_lexp); + let expr = crate::Expression::Math { + fun: crate::MathFunction::Dot, + arg: left_handle, + arg1: Some(right_handle), + arg2: None, + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::BitFieldInsert => { + inst.expect(7)?; + + let start = self.data_offset; + let span = self.span_from_with_op(start); + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + let insert_id = self.next()?; + let offset_id = self.next()?; + let count_id = self.next()?; + let base_lexp = self.lookup_expression.lookup(base_id)?; + let base_handle = get_expr_handle!(base_id, base_lexp); + let insert_lexp = self.lookup_expression.lookup(insert_id)?; + let insert_handle = get_expr_handle!(insert_id, insert_lexp); + let offset_lexp = self.lookup_expression.lookup(offset_id)?; + let offset_handle = get_expr_handle!(offset_id, offset_lexp); + let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?; + let count_lexp = self.lookup_expression.lookup(count_id)?; + let count_handle = get_expr_handle!(count_id, count_lexp); + let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?; + + let offset_kind = ctx.type_arena[offset_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + let count_kind = ctx.type_arena[count_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint { + ctx.expressions.append( + crate::Expression::As { + expr: offset_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ) + } else { + offset_handle + }; + + let count_cast_handle = if count_kind != crate::ScalarKind::Uint { + ctx.expressions.append( + crate::Expression::As { + expr: count_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ) + } else { + count_handle + }; + + let expr = crate::Expression::Math { + fun: crate::MathFunction::InsertBits, + arg: base_handle, + arg1: Some(insert_handle), + arg2: Some(offset_cast_handle), + arg3: Some(count_cast_handle), + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::BitFieldSExtract | Op::BitFieldUExtract => { + inst.expect(6)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + let offset_id = self.next()?; + let count_id = self.next()?; + let base_lexp = self.lookup_expression.lookup(base_id)?; + let base_handle = get_expr_handle!(base_id, base_lexp); + let offset_lexp = self.lookup_expression.lookup(offset_id)?; + let offset_handle = get_expr_handle!(offset_id, offset_lexp); + let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?; + let count_lexp = self.lookup_expression.lookup(count_id)?; + let count_handle = get_expr_handle!(count_id, count_lexp); + let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?; + + let offset_kind = ctx.type_arena[offset_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + let count_kind = ctx.type_arena[count_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint { + ctx.expressions.append( + crate::Expression::As { + expr: offset_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ) + } else { + offset_handle + }; + + let count_cast_handle = if count_kind != crate::ScalarKind::Uint { + ctx.expressions.append( + crate::Expression::As { + expr: count_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ) + } else { + count_handle + }; + + let expr = crate::Expression::Math { + fun: crate::MathFunction::ExtractBits, + arg: base_handle, + arg1: Some(offset_cast_handle), + arg2: Some(count_cast_handle), + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::BitReverse | Op::BitCount => { + inst.expect(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + let base_lexp = self.lookup_expression.lookup(base_id)?; + let base_handle = get_expr_handle!(base_id, base_lexp); + let expr = crate::Expression::Math { + fun: match inst.op { + Op::BitReverse => crate::MathFunction::ReverseBits, + Op::BitCount => crate::MathFunction::CountOneBits, + _ => unreachable!(), + }, + arg: base_handle, + arg1: None, + arg2: None, + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::OuterProduct => { + inst.expect(5)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let left_id = self.next()?; + let right_id = self.next()?; + let left_lexp = self.lookup_expression.lookup(left_id)?; + let left_handle = get_expr_handle!(left_id, left_lexp); + let right_lexp = self.lookup_expression.lookup(right_id)?; + let right_handle = get_expr_handle!(right_id, right_lexp); + let expr = crate::Expression::Math { + fun: crate::MathFunction::Outer, + arg: left_handle, + arg1: Some(right_handle), + arg2: None, + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + // Bitwise instructions + Op::Not => { + inst.expect(4)?; + self.parse_expr_unary_op_sign_adjusted( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + crate::UnaryOperator::BitwiseNot, + )?; + } + Op::ShiftRightLogical => { + inst.expect(5)?; + //TODO: convert input and result to unsigned + parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?; + } + Op::ShiftRightArithmetic => { + inst.expect(5)?; + //TODO: convert input and result to signed + parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?; + } + Op::ShiftLeftLogical => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::ShiftLeft, SHIFT)?; + } + // Sampling + Op::Image => { + inst.expect(4)?; + self.parse_image_uncouple(block_id)?; + } + Op::SampledImage => { + inst.expect(5)?; + self.parse_image_couple()?; + } + Op::ImageWrite => { + let extra = inst.expect_at_least(4)?; + let stmt = + self.parse_image_write(extra, ctx, &mut emitter, &mut block, body_idx)?; + block.extend(emitter.finish(ctx.expressions)); + block.push(stmt, span); + emitter.start(ctx.expressions); + } + Op::ImageFetch | Op::ImageRead => { + let extra = inst.expect_at_least(5)?; + self.parse_image_load( + extra, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageSampleImplicitLod | Op::ImageSampleExplicitLod => { + let extra = inst.expect_at_least(5)?; + let options = image::SamplingOptions { + compare: false, + project: false, + }; + self.parse_image_sample( + extra, + options, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageSampleProjImplicitLod | Op::ImageSampleProjExplicitLod => { + let extra = inst.expect_at_least(5)?; + let options = image::SamplingOptions { + compare: false, + project: true, + }; + self.parse_image_sample( + extra, + options, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageSampleDrefImplicitLod | Op::ImageSampleDrefExplicitLod => { + let extra = inst.expect_at_least(6)?; + let options = image::SamplingOptions { + compare: true, + project: false, + }; + self.parse_image_sample( + extra, + options, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageSampleProjDrefImplicitLod | Op::ImageSampleProjDrefExplicitLod => { + let extra = inst.expect_at_least(6)?; + let options = image::SamplingOptions { + compare: true, + project: true, + }; + self.parse_image_sample( + extra, + options, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageQuerySize => { + inst.expect(4)?; + self.parse_image_query_size( + false, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageQuerySizeLod => { + inst.expect(5)?; + self.parse_image_query_size( + true, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageQueryLevels => { + inst.expect(4)?; + self.parse_image_query_other(crate::ImageQuery::NumLevels, ctx, block_id)?; + } + Op::ImageQuerySamples => { + inst.expect(4)?; + self.parse_image_query_other(crate::ImageQuery::NumSamples, ctx, block_id)?; + } + // other ops + Op::Select => { + inst.expect(6)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let condition = self.next()?; + let o1_id = self.next()?; + let o2_id = self.next()?; + + let cond_lexp = self.lookup_expression.lookup(condition)?; + let cond_handle = get_expr_handle!(condition, cond_lexp); + let o1_lexp = self.lookup_expression.lookup(o1_id)?; + let o1_handle = get_expr_handle!(o1_id, o1_lexp); + let o2_lexp = self.lookup_expression.lookup(o2_id)?; + let o2_handle = get_expr_handle!(o2_id, o2_lexp); + + let expr = crate::Expression::Select { + condition: cond_handle, + accept: o1_handle, + reject: o2_handle, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::VectorShuffle => { + inst.expect_at_least(5)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let v1_id = self.next()?; + let v2_id = self.next()?; + + let v1_lexp = self.lookup_expression.lookup(v1_id)?; + let v1_lty = self.lookup_type.lookup(v1_lexp.type_id)?; + let v1_handle = get_expr_handle!(v1_id, v1_lexp); + let n1 = match ctx.type_arena[v1_lty.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u32, + _ => return Err(Error::InvalidInnerType(v1_lexp.type_id)), + }; + let v2_lexp = self.lookup_expression.lookup(v2_id)?; + let v2_lty = self.lookup_type.lookup(v2_lexp.type_id)?; + let v2_handle = get_expr_handle!(v2_id, v2_lexp); + let n2 = match ctx.type_arena[v2_lty.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u32, + _ => return Err(Error::InvalidInnerType(v2_lexp.type_id)), + }; + + self.temp_bytes.clear(); + let mut max_component = 0; + for _ in 5..inst.wc as usize { + let mut index = self.next()?; + if index == u32::MAX { + // treat Undefined as X + index = 0; + } + max_component = max_component.max(index); + self.temp_bytes.push(index as u8); + } + + // Check for swizzle first. + let expr = if max_component < n1 { + use crate::SwizzleComponent as Sc; + let size = match self.temp_bytes.len() { + 2 => crate::VectorSize::Bi, + 3 => crate::VectorSize::Tri, + _ => crate::VectorSize::Quad, + }; + let mut pattern = [Sc::X; 4]; + for (pat, index) in pattern.iter_mut().zip(self.temp_bytes.drain(..)) { + *pat = match index { + 0 => Sc::X, + 1 => Sc::Y, + 2 => Sc::Z, + _ => Sc::W, + }; + } + crate::Expression::Swizzle { + size, + vector: v1_handle, + pattern, + } + } else { + // Fall back to access + compose + let mut components = Vec::with_capacity(self.temp_bytes.len()); + for index in self.temp_bytes.drain(..).map(|i| i as u32) { + let expr = if index < n1 { + crate::Expression::AccessIndex { + base: v1_handle, + index, + } + } else if index < n1 + n2 { + crate::Expression::AccessIndex { + base: v2_handle, + index: index - n1, + } + } else { + return Err(Error::InvalidAccessIndex(index)); + }; + components.push(ctx.expressions.append(expr, span)); + } + crate::Expression::Compose { + ty: self.lookup_type.lookup(result_type_id)?.handle, + components, + } + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::Bitcast + | Op::ConvertSToF + | Op::ConvertUToF + | Op::ConvertFToU + | Op::ConvertFToS + | Op::FConvert + | Op::UConvert + | Op::SConvert => { + inst.expect(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let value_id = self.next()?; + + let value_lexp = self.lookup_expression.lookup(value_id)?; + let ty_lookup = self.lookup_type.lookup(result_type_id)?; + let scalar = match ctx.type_arena[ty_lookup.handle].inner { + crate::TypeInner::Scalar(scalar) + | crate::TypeInner::Vector { scalar, .. } + | crate::TypeInner::Matrix { scalar, .. } => scalar, + _ => return Err(Error::InvalidAsType(ty_lookup.handle)), + }; + + let expr = crate::Expression::As { + expr: get_expr_handle!(value_id, value_lexp), + kind: scalar.kind, + convert: if scalar.kind == crate::ScalarKind::Bool { + Some(crate::BOOL_WIDTH) + } else if inst.op == Op::Bitcast { + None + } else { + Some(scalar.width) + }, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::FunctionCall => { + inst.expect_at_least(4)?; + block.extend(emitter.finish(ctx.expressions)); + + let result_type_id = self.next()?; + let result_id = self.next()?; + let func_id = self.next()?; + + let mut arguments = Vec::with_capacity(inst.wc as usize - 4); + for _ in 0..arguments.capacity() { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + arguments.push(get_expr_handle!(arg_id, lexp)); + } + + // We just need an unique handle here, nothing more. + let function = self.add_call(ctx.function_id, func_id); + + let result = if self.lookup_void_type == Some(result_type_id) { + None + } else { + let expr_handle = ctx + .expressions + .append(crate::Expression::CallResult(function), span); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expr_handle, + type_id: result_type_id, + block_id, + }, + ); + Some(expr_handle) + }; + block.push( + crate::Statement::Call { + function, + arguments, + result, + }, + span, + ); + emitter.start(ctx.expressions); + } + Op::ExtInst => { + use crate::MathFunction as Mf; + use spirv::GLOp as Glo; + + let base_wc = 5; + inst.expect_at_least(base_wc)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let set_id = self.next()?; + if Some(set_id) != self.ext_glsl_id { + return Err(Error::UnsupportedExtInstSet(set_id)); + } + let inst_id = self.next()?; + let gl_op = Glo::from_u32(inst_id).ok_or(Error::UnsupportedExtInst(inst_id))?; + + let fun = match gl_op { + Glo::Round => Mf::Round, + Glo::RoundEven => Mf::Round, + Glo::Trunc => Mf::Trunc, + Glo::FAbs | Glo::SAbs => Mf::Abs, + Glo::FSign | Glo::SSign => Mf::Sign, + Glo::Floor => Mf::Floor, + Glo::Ceil => Mf::Ceil, + Glo::Fract => Mf::Fract, + Glo::Sin => Mf::Sin, + Glo::Cos => Mf::Cos, + Glo::Tan => Mf::Tan, + Glo::Asin => Mf::Asin, + Glo::Acos => Mf::Acos, + Glo::Atan => Mf::Atan, + Glo::Sinh => Mf::Sinh, + Glo::Cosh => Mf::Cosh, + Glo::Tanh => Mf::Tanh, + Glo::Atan2 => Mf::Atan2, + Glo::Asinh => Mf::Asinh, + Glo::Acosh => Mf::Acosh, + Glo::Atanh => Mf::Atanh, + Glo::Radians => Mf::Radians, + Glo::Degrees => Mf::Degrees, + Glo::Pow => Mf::Pow, + Glo::Exp => Mf::Exp, + Glo::Log => Mf::Log, + Glo::Exp2 => Mf::Exp2, + Glo::Log2 => Mf::Log2, + Glo::Sqrt => Mf::Sqrt, + Glo::InverseSqrt => Mf::InverseSqrt, + Glo::MatrixInverse => Mf::Inverse, + Glo::Determinant => Mf::Determinant, + Glo::ModfStruct => Mf::Modf, + Glo::FMin | Glo::UMin | Glo::SMin | Glo::NMin => Mf::Min, + Glo::FMax | Glo::UMax | Glo::SMax | Glo::NMax => Mf::Max, + Glo::FClamp | Glo::UClamp | Glo::SClamp | Glo::NClamp => Mf::Clamp, + Glo::FMix => Mf::Mix, + Glo::Step => Mf::Step, + Glo::SmoothStep => Mf::SmoothStep, + Glo::Fma => Mf::Fma, + Glo::FrexpStruct => Mf::Frexp, + Glo::Ldexp => Mf::Ldexp, + Glo::Length => Mf::Length, + Glo::Distance => Mf::Distance, + Glo::Cross => Mf::Cross, + Glo::Normalize => Mf::Normalize, + Glo::FaceForward => Mf::FaceForward, + Glo::Reflect => Mf::Reflect, + Glo::Refract => Mf::Refract, + Glo::PackUnorm4x8 => Mf::Pack4x8unorm, + Glo::PackSnorm4x8 => Mf::Pack4x8snorm, + Glo::PackHalf2x16 => Mf::Pack2x16float, + Glo::PackUnorm2x16 => Mf::Pack2x16unorm, + Glo::PackSnorm2x16 => Mf::Pack2x16snorm, + Glo::UnpackUnorm4x8 => Mf::Unpack4x8unorm, + Glo::UnpackSnorm4x8 => Mf::Unpack4x8snorm, + Glo::UnpackHalf2x16 => Mf::Unpack2x16float, + Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm, + Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm, + Glo::FindILsb => Mf::FindLsb, + Glo::FindUMsb | Glo::FindSMsb => Mf::FindMsb, + // TODO: https://github.com/gfx-rs/naga/issues/2526 + Glo::Modf | Glo::Frexp => return Err(Error::UnsupportedExtInst(inst_id)), + Glo::IMix + | Glo::PackDouble2x32 + | Glo::UnpackDouble2x32 + | Glo::InterpolateAtCentroid + | Glo::InterpolateAtSample + | Glo::InterpolateAtOffset => { + return Err(Error::UnsupportedExtInst(inst_id)) + } + }; + + let arg_count = fun.argument_count(); + inst.expect(base_wc + arg_count as u16)?; + let arg = { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + get_expr_handle!(arg_id, lexp) + }; + let arg1 = if arg_count > 1 { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + Some(get_expr_handle!(arg_id, lexp)) + } else { + None + }; + let arg2 = if arg_count > 2 { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + Some(get_expr_handle!(arg_id, lexp)) + } else { + None + }; + let arg3 = if arg_count > 3 { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + Some(get_expr_handle!(arg_id, lexp)) + } else { + None + }; + + let expr = crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + // Relational and Logical Instructions + Op::LogicalNot => { + inst.expect(4)?; + parse_expr_op!(crate::UnaryOperator::LogicalNot, UNARY)?; + } + Op::LogicalOr => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::LogicalOr, BINARY)?; + } + Op::LogicalAnd => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::LogicalAnd, BINARY)?; + } + Op::SGreaterThan | Op::SGreaterThanEqual | Op::SLessThan | Op::SLessThanEqual => { + inst.expect(5)?; + self.parse_expr_int_comparison( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + map_binary_operator(inst.op)?, + crate::ScalarKind::Sint, + )?; + } + Op::UGreaterThan | Op::UGreaterThanEqual | Op::ULessThan | Op::ULessThanEqual => { + inst.expect(5)?; + self.parse_expr_int_comparison( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + map_binary_operator(inst.op)?, + crate::ScalarKind::Uint, + )?; + } + Op::FOrdEqual + | Op::FUnordEqual + | Op::FOrdNotEqual + | Op::FUnordNotEqual + | Op::FOrdLessThan + | Op::FUnordLessThan + | Op::FOrdGreaterThan + | Op::FUnordGreaterThan + | Op::FOrdLessThanEqual + | Op::FUnordLessThanEqual + | Op::FOrdGreaterThanEqual + | Op::FUnordGreaterThanEqual + | Op::LogicalEqual + | Op::LogicalNotEqual => { + inst.expect(5)?; + let operator = map_binary_operator(inst.op)?; + parse_expr_op!(operator, BINARY)?; + } + Op::Any | Op::All | Op::IsNan | Op::IsInf | Op::IsFinite | Op::IsNormal => { + inst.expect(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let arg_id = self.next()?; + + let arg_lexp = self.lookup_expression.lookup(arg_id)?; + let arg_handle = get_expr_handle!(arg_id, arg_lexp); + + let expr = crate::Expression::Relational { + fun: map_relational_fun(inst.op)?, + argument: arg_handle, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::Kill => { + inst.expect(1)?; + break Some(crate::Statement::Kill); + } + Op::Unreachable => { + inst.expect(1)?; + break None; + } + Op::Return => { + inst.expect(1)?; + break Some(crate::Statement::Return { value: None }); + } + Op::ReturnValue => { + inst.expect(2)?; + let value_id = self.next()?; + let value_lexp = self.lookup_expression.lookup(value_id)?; + let value_handle = get_expr_handle!(value_id, value_lexp); + break Some(crate::Statement::Return { + value: Some(value_handle), + }); + } + Op::Branch => { + inst.expect(2)?; + let target_id = self.next()?; + + // If this is a branch to a merge or continue block, then + // that ends the current body. + // + // Why can we count on finding an entry here when it's + // needed? SPIR-V requires dominators to appear before + // blocks they dominate, so we will have visited a + // structured control construct's header block before + // anything that could exit it. + if let Some(info) = ctx.mergers.get(&target_id) { + block.extend(emitter.finish(ctx.expressions)); + ctx.blocks.insert(block_id, block); + let body = &mut ctx.bodies[body_idx]; + body.data.push(BodyFragment::BlockId(block_id)); + + merger(body, info); + + return Ok(()); + } + + // If `target_id` has no entry in `ctx.body_for_label`, then + // this must be the only branch to it: + // + // - We've already established that it's not anybody's merge + // block. + // + // - It can't be a switch case. Only switch header blocks + // and other switch cases can branch to a switch case. + // Switch header blocks must dominate all their cases, so + // they must appear in the file before them, and when we + // see `Op::Switch` we populate `ctx.body_for_label` for + // every switch case. + // + // Thus, `target_id` must be a simple extension of the + // current block, which we dominate, so we know we'll + // encounter it later in the file. + ctx.body_for_label.entry(target_id).or_insert(body_idx); + + break None; + } + Op::BranchConditional => { + inst.expect_at_least(4)?; + + let condition = { + let condition_id = self.next()?; + let lexp = self.lookup_expression.lookup(condition_id)?; + get_expr_handle!(condition_id, lexp) + }; + + // HACK(eddyb) Naga doesn't seem to have this helper, + // so it's declared on the fly here for convenience. + #[derive(Copy, Clone)] + struct BranchTarget { + label_id: spirv::Word, + merge_info: Option<MergeBlockInformation>, + } + let branch_target = |label_id| BranchTarget { + label_id, + merge_info: ctx.mergers.get(&label_id).copied(), + }; + + let true_target = branch_target(self.next()?); + let false_target = branch_target(self.next()?); + + // Consume branch weights + for _ in 4..inst.wc { + let _ = self.next()?; + } + + // Handle `OpBranchConditional`s used at the end of a loop + // body's "continuing" section as a "conditional backedge", + // i.e. a `do`-`while` condition, or `break if` in WGSL. + + // HACK(eddyb) this has to go to the parent *twice*, because + // `OpLoopMerge` left the "continuing" section nested in the + // loop body in terms of `parent`, but not `BodyFragment`. + let parent_body_idx = ctx.bodies[body_idx].parent; + let parent_parent_body_idx = ctx.bodies[parent_body_idx].parent; + match ctx.bodies[parent_parent_body_idx].data[..] { + // The `OpLoopMerge`'s `continuing` block and the loop's + // backedge block may not be the same, but they'll both + // belong to the same body. + [.., BodyFragment::Loop { + body: loop_body_idx, + continuing: loop_continuing_idx, + break_if: ref mut break_if_slot @ None, + }] if body_idx == loop_continuing_idx => { + // Try both orderings of break-vs-backedge, because + // SPIR-V is symmetrical here, unlike WGSL `break if`. + let break_if_cond = [true, false].into_iter().find_map(|true_breaks| { + let (break_candidate, backedge_candidate) = if true_breaks { + (true_target, false_target) + } else { + (false_target, true_target) + }; + + if break_candidate.merge_info + != Some(MergeBlockInformation::LoopMerge) + { + return None; + } + + // HACK(eddyb) since Naga doesn't explicitly track + // backedges, this is checking for the outcome of + // `OpLoopMerge` below (even if it looks weird). + let backedge_candidate_is_backedge = + backedge_candidate.merge_info.is_none() + && ctx.body_for_label.get(&backedge_candidate.label_id) + == Some(&loop_body_idx); + if !backedge_candidate_is_backedge { + return None; + } + + Some(if true_breaks { + condition + } else { + ctx.expressions.append( + crate::Expression::Unary { + op: crate::UnaryOperator::LogicalNot, + expr: condition, + }, + span, + ) + }) + }); + + if let Some(break_if_cond) = break_if_cond { + *break_if_slot = Some(break_if_cond); + + // This `OpBranchConditional` ends the "continuing" + // section of the loop body as normal, with the + // `break if` condition having been stashed above. + break None; + } + } + _ => {} + } + + block.extend(emitter.finish(ctx.expressions)); + ctx.blocks.insert(block_id, block); + let body = &mut ctx.bodies[body_idx]; + body.data.push(BodyFragment::BlockId(block_id)); + + let same_target = true_target.label_id == false_target.label_id; + + // Start a body block for the `accept` branch. + let accept = ctx.bodies.len(); + let mut accept_block = Body::with_parent(body_idx); + + // If the `OpBranchConditional` target is somebody else's + // merge or continue block, then put a `Break` or `Continue` + // statement in this new body block. + if let Some(info) = true_target.merge_info { + merger( + match same_target { + true => &mut ctx.bodies[body_idx], + false => &mut accept_block, + }, + &info, + ) + } else { + // Note the body index for the block we're branching to. + let prev = ctx.body_for_label.insert( + true_target.label_id, + match same_target { + true => body_idx, + false => accept, + }, + ); + debug_assert!(prev.is_none()); + } + + if same_target { + return Ok(()); + } + + ctx.bodies.push(accept_block); + + // Handle the `reject` branch just like the `accept` block. + let reject = ctx.bodies.len(); + let mut reject_block = Body::with_parent(body_idx); + + if let Some(info) = false_target.merge_info { + merger(&mut reject_block, &info) + } else { + let prev = ctx.body_for_label.insert(false_target.label_id, reject); + debug_assert!(prev.is_none()); + } + + ctx.bodies.push(reject_block); + + let body = &mut ctx.bodies[body_idx]; + body.data.push(BodyFragment::If { + condition, + accept, + reject, + }); + + return Ok(()); + } + Op::Switch => { + inst.expect_at_least(3)?; + let selector = self.next()?; + let default_id = self.next()?; + + // If the previous instruction was a `OpSelectionMerge` then we must + // promote the `MergeBlockInformation` to a `SwitchMerge` + if let Some(merge) = selection_merge_block { + ctx.mergers + .insert(merge, MergeBlockInformation::SwitchMerge); + } + + let default = ctx.bodies.len(); + ctx.bodies.push(Body::with_parent(body_idx)); + ctx.body_for_label.entry(default_id).or_insert(default); + + let selector_lexp = &self.lookup_expression[&selector]; + let selector_lty = self.lookup_type.lookup(selector_lexp.type_id)?; + let selector_handle = get_expr_handle!(selector, selector_lexp); + let selector = match ctx.type_arena[selector_lty.handle].inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + }) => { + // IR expects a signed integer, so do a bitcast + ctx.expressions.append( + crate::Expression::As { + kind: crate::ScalarKind::Sint, + expr: selector_handle, + convert: None, + }, + span, + ) + } + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + }) => selector_handle, + ref other => unimplemented!("Unexpected selector {:?}", other), + }; + + // Clear past switch cases to prevent them from entering this one + self.switch_cases.clear(); + + for _ in 0..(inst.wc - 3) / 2 { + let literal = self.next()?; + let target = self.next()?; + + let case_body_idx = ctx.bodies.len(); + + // Check if any previous case already used this target block id, if so + // group them together to reorder them later so that no weird + // fallthrough cases happen. + if let Some(&mut (_, ref mut literals)) = self.switch_cases.get_mut(&target) + { + literals.push(literal as i32); + continue; + } + + let mut body = Body::with_parent(body_idx); + + if let Some(info) = ctx.mergers.get(&target) { + merger(&mut body, info); + } + + ctx.bodies.push(body); + ctx.body_for_label.entry(target).or_insert(case_body_idx); + + // Register this target block id as already having been processed and + // the respective body index assigned and the first case value + self.switch_cases + .insert(target, (case_body_idx, vec![literal as i32])); + } + + // Loop trough the collected target blocks creating a new case for each + // literal pointing to it, only one case will have the true body and all the + // others will be empty fallthrough so that they all execute the same body + // without duplicating code. + // + // Since `switch_cases` is an indexmap the order of insertion is preserved + // this is needed because spir-v defines fallthrough order in the switch + // instruction. + let mut cases = Vec::with_capacity((inst.wc as usize - 3) / 2); + for &(case_body_idx, ref literals) in self.switch_cases.values() { + let value = literals[0]; + + for &literal in literals.iter().skip(1) { + let empty_body_idx = ctx.bodies.len(); + let body = Body::with_parent(body_idx); + + ctx.bodies.push(body); + + cases.push((literal, empty_body_idx)); + } + + cases.push((value, case_body_idx)); + } + + block.extend(emitter.finish(ctx.expressions)); + + let body = &mut ctx.bodies[body_idx]; + ctx.blocks.insert(block_id, block); + // Make sure the vector has space for at least two more allocations + body.data.reserve(2); + body.data.push(BodyFragment::BlockId(block_id)); + body.data.push(BodyFragment::Switch { + selector, + cases, + default, + }); + + return Ok(()); + } + Op::SelectionMerge => { + inst.expect(3)?; + let merge_block_id = self.next()?; + // TODO: Selection Control Mask + let _selection_control = self.next()?; + + // Indicate that the merge block is a continuation of the + // current `Body`. + ctx.body_for_label.entry(merge_block_id).or_insert(body_idx); + + // Let subsequent branches to the merge block know that + // they've reached the end of the selection construct. + ctx.mergers + .insert(merge_block_id, MergeBlockInformation::SelectionMerge); + + selection_merge_block = Some(merge_block_id); + } + Op::LoopMerge => { + inst.expect_at_least(4)?; + let merge_block_id = self.next()?; + let continuing = self.next()?; + + // TODO: Loop Control Parameters + for _ in 0..inst.wc - 3 { + self.next()?; + } + + // Indicate that the merge block is a continuation of the + // current `Body`. + ctx.body_for_label.entry(merge_block_id).or_insert(body_idx); + // Let subsequent branches to the merge block know that + // they're `Break` statements. + ctx.mergers + .insert(merge_block_id, MergeBlockInformation::LoopMerge); + + let loop_body_idx = ctx.bodies.len(); + ctx.bodies.push(Body::with_parent(body_idx)); + + let continue_idx = ctx.bodies.len(); + // The continue block inherits the scope of the loop body + ctx.bodies.push(Body::with_parent(loop_body_idx)); + ctx.body_for_label.entry(continuing).or_insert(continue_idx); + // Let subsequent branches to the continue block know that + // they're `Continue` statements. + ctx.mergers + .insert(continuing, MergeBlockInformation::LoopContinue); + + // The loop header always belongs to the loop body + ctx.body_for_label.insert(block_id, loop_body_idx); + + let parent_body = &mut ctx.bodies[body_idx]; + parent_body.data.push(BodyFragment::Loop { + body: loop_body_idx, + continuing: continue_idx, + break_if: None, + }); + body_idx = loop_body_idx; + } + Op::DPdxCoarse => { + parse_expr_op!( + crate::DerivativeAxis::X, + crate::DerivativeControl::Coarse, + DERIVATIVE + )?; + } + Op::DPdyCoarse => { + parse_expr_op!( + crate::DerivativeAxis::Y, + crate::DerivativeControl::Coarse, + DERIVATIVE + )?; + } + Op::FwidthCoarse => { + parse_expr_op!( + crate::DerivativeAxis::Width, + crate::DerivativeControl::Coarse, + DERIVATIVE + )?; + } + Op::DPdxFine => { + parse_expr_op!( + crate::DerivativeAxis::X, + crate::DerivativeControl::Fine, + DERIVATIVE + )?; + } + Op::DPdyFine => { + parse_expr_op!( + crate::DerivativeAxis::Y, + crate::DerivativeControl::Fine, + DERIVATIVE + )?; + } + Op::FwidthFine => { + parse_expr_op!( + crate::DerivativeAxis::Width, + crate::DerivativeControl::Fine, + DERIVATIVE + )?; + } + Op::DPdx => { + parse_expr_op!( + crate::DerivativeAxis::X, + crate::DerivativeControl::None, + DERIVATIVE + )?; + } + Op::DPdy => { + parse_expr_op!( + crate::DerivativeAxis::Y, + crate::DerivativeControl::None, + DERIVATIVE + )?; + } + Op::Fwidth => { + parse_expr_op!( + crate::DerivativeAxis::Width, + crate::DerivativeControl::None, + DERIVATIVE + )?; + } + Op::ArrayLength => { + inst.expect(5)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let structure_id = self.next()?; + let member_index = self.next()?; + + // We're assuming that the validation pass, if it's run, will catch if the + // wrong types or parameters are supplied here. + + let structure_ptr = self.lookup_expression.lookup(structure_id)?; + let structure_handle = get_expr_handle!(structure_id, structure_ptr); + + let member_ptr = ctx.expressions.append( + crate::Expression::AccessIndex { + base: structure_handle, + index: member_index, + }, + span, + ); + + let length = ctx + .expressions + .append(crate::Expression::ArrayLength(member_ptr), span); + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: length, + type_id: result_type_id, + block_id, + }, + ); + } + Op::CopyMemory => { + inst.expect_at_least(3)?; + let target_id = self.next()?; + let source_id = self.next()?; + let _memory_access = if inst.wc != 3 { + inst.expect(4)?; + spirv::MemoryAccess::from_bits(self.next()?) + .ok_or(Error::InvalidParameter(Op::CopyMemory))? + } else { + spirv::MemoryAccess::NONE + }; + + // TODO: check if the source and target types are the same? + let target = self.lookup_expression.lookup(target_id)?; + let target_handle = get_expr_handle!(target_id, target); + let source = self.lookup_expression.lookup(source_id)?; + let source_handle = get_expr_handle!(source_id, source); + + // This operation is practically the same as loading and then storing, I think. + let value_expr = ctx.expressions.append( + crate::Expression::Load { + pointer: source_handle, + }, + span, + ); + + block.extend(emitter.finish(ctx.expressions)); + block.push( + crate::Statement::Store { + pointer: target_handle, + value: value_expr, + }, + span, + ); + + emitter.start(ctx.expressions); + } + Op::ControlBarrier => { + inst.expect(4)?; + let exec_scope_id = self.next()?; + let _mem_scope_raw = self.next()?; + let semantics_id = self.next()?; + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let semantics_const = self.lookup_constant.lookup(semantics_id)?; + + let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) + .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; + + if exec_scope == spirv::Scope::Workgroup as u32 { + let mut flags = crate::Barrier::empty(); + flags.set( + crate::Barrier::STORAGE, + semantics & spirv::MemorySemantics::UNIFORM_MEMORY.bits() != 0, + ); + flags.set( + crate::Barrier::WORK_GROUP, + semantics + & (spirv::MemorySemantics::SUBGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY) + .bits() + != 0, + ); + block.push(crate::Statement::Barrier(flags), span); + } else { + log::warn!("Unsupported barrier execution scope: {}", exec_scope); + } + } + Op::CopyObject => { + inst.expect(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let operand_id = self.next()?; + + let lookup = self.lookup_expression.lookup(operand_id)?; + let handle = get_expr_handle!(operand_id, lookup); + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), + } + }; + + block.extend(emitter.finish(ctx.expressions)); + if let Some(stmt) = terminator { + block.push(stmt, crate::Span::default()); + } + + // Save this block fragment in `block_ctx.blocks`, and mark it to be + // incorporated into the current body at `Statement` assembly time. + ctx.blocks.insert(block_id, block); + let body = &mut ctx.bodies[body_idx]; + body.data.push(BodyFragment::BlockId(block_id)); + Ok(()) + } + + fn make_expression_storage( + &mut self, + globals: &Arena<crate::GlobalVariable>, + constants: &Arena<crate::Constant>, + ) -> Arena<crate::Expression> { + let mut expressions = Arena::new(); + #[allow(clippy::panic)] + { + assert!(self.lookup_expression.is_empty()); + } + // register global variables + for (&id, var) in self.lookup_variable.iter() { + let span = globals.get_span(var.handle); + let handle = expressions.append(crate::Expression::GlobalVariable(var.handle), span); + self.lookup_expression.insert( + id, + LookupExpression { + type_id: var.type_id, + handle, + // Setting this to an invalid id will cause get_expr_handle + // to default to the main body making sure no load/stores + // are added. + block_id: 0, + }, + ); + } + // register constants + for (&id, con) in self.lookup_constant.iter() { + let span = constants.get_span(con.handle); + let handle = expressions.append(crate::Expression::Constant(con.handle), span); + self.lookup_expression.insert( + id, + LookupExpression { + type_id: con.type_id, + handle, + // Setting this to an invalid id will cause get_expr_handle + // to default to the main body making sure no load/stores + // are added. + block_id: 0, + }, + ); + } + // done + expressions + } + + fn switch(&mut self, state: ModuleState, op: spirv::Op) -> Result<(), Error> { + if state < self.state { + Err(Error::UnsupportedInstruction(self.state, op)) + } else { + self.state = state; + Ok(()) + } + } + + /// Walk the statement tree and patch it in the following cases: + /// 1. Function call targets are replaced by `deferred_function_calls` map + fn patch_statements( + &mut self, + statements: &mut crate::Block, + expressions: &mut Arena<crate::Expression>, + fun_parameter_sampling: &mut [image::SamplingFlags], + ) -> Result<(), Error> { + use crate::Statement as S; + let mut i = 0usize; + while i < statements.len() { + match statements[i] { + S::Emit(_) => {} + S::Block(ref mut block) => { + self.patch_statements(block, expressions, fun_parameter_sampling)?; + } + S::If { + condition: _, + ref mut accept, + ref mut reject, + } => { + self.patch_statements(reject, expressions, fun_parameter_sampling)?; + self.patch_statements(accept, expressions, fun_parameter_sampling)?; + } + S::Switch { + selector: _, + ref mut cases, + } => { + for case in cases.iter_mut() { + self.patch_statements(&mut case.body, expressions, fun_parameter_sampling)?; + } + } + S::Loop { + ref mut body, + ref mut continuing, + break_if: _, + } => { + self.patch_statements(body, expressions, fun_parameter_sampling)?; + self.patch_statements(continuing, expressions, fun_parameter_sampling)?; + } + S::Break + | S::Continue + | S::Return { .. } + | S::Kill + | S::Barrier(_) + | S::Store { .. } + | S::ImageStore { .. } + | S::Atomic { .. } + | S::RayQuery { .. } => {} + S::Call { + function: ref mut callee, + ref arguments, + .. + } => { + let fun_id = self.deferred_function_calls[callee.index()]; + let fun_lookup = self.lookup_function.lookup(fun_id)?; + *callee = fun_lookup.handle; + + // Patch sampling flags + for (arg_index, arg) in arguments.iter().enumerate() { + let flags = match fun_lookup.parameters_sampling.get(arg_index) { + Some(&flags) if !flags.is_empty() => flags, + _ => continue, + }; + + match expressions[*arg] { + crate::Expression::GlobalVariable(handle) => { + if let Some(sampling) = self.handle_sampling.get_mut(&handle) { + *sampling |= flags + } + } + crate::Expression::FunctionArgument(i) => { + fun_parameter_sampling[i as usize] |= flags; + } + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + } + } + } + S::WorkGroupUniformLoad { .. } => unreachable!(), + } + i += 1; + } + Ok(()) + } + + fn patch_function( + &mut self, + handle: Option<Handle<crate::Function>>, + fun: &mut crate::Function, + ) -> Result<(), Error> { + // Note: this search is a bit unfortunate + let (fun_id, mut parameters_sampling) = match handle { + Some(h) => { + let (&fun_id, lookup) = self + .lookup_function + .iter_mut() + .find(|&(_, ref lookup)| lookup.handle == h) + .unwrap(); + (fun_id, mem::take(&mut lookup.parameters_sampling)) + } + None => (0, Vec::new()), + }; + + for (_, expr) in fun.expressions.iter_mut() { + if let crate::Expression::CallResult(ref mut function) = *expr { + let fun_id = self.deferred_function_calls[function.index()]; + *function = self.lookup_function.lookup(fun_id)?.handle; + } + } + + self.patch_statements( + &mut fun.body, + &mut fun.expressions, + &mut parameters_sampling, + )?; + + if let Some(lookup) = self.lookup_function.get_mut(&fun_id) { + lookup.parameters_sampling = parameters_sampling; + } + Ok(()) + } + + pub fn parse(mut self) -> Result<crate::Module, Error> { + let mut module = { + if self.next()? != spirv::MAGIC_NUMBER { + return Err(Error::InvalidHeader); + } + let version_raw = self.next()?; + let generator = self.next()?; + let _bound = self.next()?; + let _schema = self.next()?; + log::info!("Generated by {} version {:x}", generator, version_raw); + crate::Module::default() + }; + + self.layouter.clear(); + self.dummy_functions = Arena::new(); + self.lookup_function.clear(); + self.function_call_graph.clear(); + + loop { + use spirv::Op; + + let inst = match self.next_inst() { + Ok(inst) => inst, + Err(Error::IncompleteData) => break, + Err(other) => return Err(other), + }; + log::debug!("\t{:?} [{}]", inst.op, inst.wc); + + match inst.op { + Op::Capability => self.parse_capability(inst), + Op::Extension => self.parse_extension(inst), + Op::ExtInstImport => self.parse_ext_inst_import(inst), + Op::MemoryModel => self.parse_memory_model(inst), + Op::EntryPoint => self.parse_entry_point(inst), + Op::ExecutionMode => self.parse_execution_mode(inst), + Op::String => self.parse_string(inst), + Op::Source => self.parse_source(inst), + Op::SourceExtension => self.parse_source_extension(inst), + Op::Name => self.parse_name(inst), + Op::MemberName => self.parse_member_name(inst), + Op::ModuleProcessed => self.parse_module_processed(inst), + Op::Decorate => self.parse_decorate(inst), + Op::MemberDecorate => self.parse_member_decorate(inst), + Op::TypeVoid => self.parse_type_void(inst), + Op::TypeBool => self.parse_type_bool(inst, &mut module), + Op::TypeInt => self.parse_type_int(inst, &mut module), + Op::TypeFloat => self.parse_type_float(inst, &mut module), + Op::TypeVector => self.parse_type_vector(inst, &mut module), + Op::TypeMatrix => self.parse_type_matrix(inst, &mut module), + Op::TypeFunction => self.parse_type_function(inst), + Op::TypePointer => self.parse_type_pointer(inst, &mut module), + Op::TypeArray => self.parse_type_array(inst, &mut module), + Op::TypeRuntimeArray => self.parse_type_runtime_array(inst, &mut module), + Op::TypeStruct => self.parse_type_struct(inst, &mut module), + Op::TypeImage => self.parse_type_image(inst, &mut module), + Op::TypeSampledImage => self.parse_type_sampled_image(inst), + Op::TypeSampler => self.parse_type_sampler(inst, &mut module), + Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), + Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), + Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module), + Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module), + Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module), + Op::Variable => self.parse_global_variable(inst, &mut module), + Op::Function => { + self.switch(ModuleState::Function, inst.op)?; + inst.expect(5)?; + self.parse_function(&mut module) + } + _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO + }?; + } + + log::info!("Patching..."); + { + let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None) + .map_err(|cycle| Error::FunctionCallCycle(cycle.node_id()))?; + nodes.reverse(); // we need dominated first + let mut functions = mem::take(&mut module.functions); + for fun_id in nodes { + if fun_id > !(functions.len() as u32) { + // skip all the fake IDs registered for the entry points + continue; + } + let lookup = self.lookup_function.get_mut(&fun_id).unwrap(); + // take out the function from the old array + let fun = mem::take(&mut functions[lookup.handle]); + // add it to the newly formed arena, and adjust the lookup + lookup.handle = module + .functions + .append(fun, functions.get_span(lookup.handle)); + } + } + // patch all the functions + for (handle, fun) in module.functions.iter_mut() { + self.patch_function(Some(handle), fun)?; + } + for ep in module.entry_points.iter_mut() { + self.patch_function(None, &mut ep.function)?; + } + + // Check all the images and samplers to have consistent comparison property. + for (handle, flags) in self.handle_sampling.drain() { + if !image::patch_comparison_type( + flags, + module.global_variables.get_mut(handle), + &mut module.types, + ) { + return Err(Error::InconsistentComparisonSampling(handle)); + } + } + + if !self.future_decor.is_empty() { + log::warn!("Unused item decorations: {:?}", self.future_decor); + self.future_decor.clear(); + } + if !self.future_member_decor.is_empty() { + log::warn!("Unused member decorations: {:?}", self.future_member_decor); + self.future_member_decor.clear(); + } + + Ok(module) + } + + fn parse_capability(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Capability, inst.op)?; + inst.expect(2)?; + let capability = self.next()?; + let cap = + spirv::Capability::from_u32(capability).ok_or(Error::UnknownCapability(capability))?; + if !SUPPORTED_CAPABILITIES.contains(&cap) { + if self.options.strict_capabilities { + return Err(Error::UnsupportedCapability(cap)); + } else { + log::warn!("Unknown capability {:?}", cap); + } + } + Ok(()) + } + + fn parse_extension(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Extension, inst.op)?; + inst.expect_at_least(2)?; + let (name, left) = self.next_string(inst.wc - 1)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + if !SUPPORTED_EXTENSIONS.contains(&name.as_str()) { + return Err(Error::UnsupportedExtension(name)); + } + Ok(()) + } + + fn parse_ext_inst_import(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Extension, inst.op)?; + inst.expect_at_least(3)?; + let result_id = self.next()?; + let (name, left) = self.next_string(inst.wc - 2)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + if !SUPPORTED_EXT_SETS.contains(&name.as_str()) { + return Err(Error::UnsupportedExtSet(name)); + } + self.ext_glsl_id = Some(result_id); + Ok(()) + } + + fn parse_memory_model(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::MemoryModel, inst.op)?; + inst.expect(3)?; + let _addressing_model = self.next()?; + let _memory_model = self.next()?; + Ok(()) + } + + fn parse_entry_point(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::EntryPoint, inst.op)?; + inst.expect_at_least(4)?; + let exec_model = self.next()?; + let exec_model = spirv::ExecutionModel::from_u32(exec_model) + .ok_or(Error::UnsupportedExecutionModel(exec_model))?; + let function_id = self.next()?; + let (name, left) = self.next_string(inst.wc - 3)?; + let ep = EntryPoint { + stage: match exec_model { + spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex, + spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment, + spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute, + _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)), + }, + name, + early_depth_test: None, + workgroup_size: [0; 3], + variable_ids: self.data.by_ref().take(left as usize).collect(), + }; + self.lookup_entry_point.insert(function_id, ep); + Ok(()) + } + + fn parse_execution_mode(&mut self, inst: Instruction) -> Result<(), Error> { + use spirv::ExecutionMode; + + self.switch(ModuleState::ExecutionMode, inst.op)?; + inst.expect_at_least(3)?; + + let ep_id = self.next()?; + let mode_id = self.next()?; + let args: Vec<spirv::Word> = self.data.by_ref().take(inst.wc as usize - 3).collect(); + + let ep = self + .lookup_entry_point + .get_mut(&ep_id) + .ok_or(Error::InvalidId(ep_id))?; + let mode = spirv::ExecutionMode::from_u32(mode_id) + .ok_or(Error::UnsupportedExecutionMode(mode_id))?; + + match mode { + ExecutionMode::EarlyFragmentTests => { + if ep.early_depth_test.is_none() { + ep.early_depth_test = Some(crate::EarlyDepthTest { conservative: None }); + } + } + ExecutionMode::DepthUnchanged => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::Unchanged), + }); + } + ExecutionMode::DepthGreater => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::GreaterEqual), + }); + } + ExecutionMode::DepthLess => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::LessEqual), + }); + } + ExecutionMode::DepthReplacing => { + // Ignored because it can be deduced from the IR. + } + ExecutionMode::OriginUpperLeft => { + // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode. + } + ExecutionMode::LocalSize => { + ep.workgroup_size = [args[0], args[1], args[2]]; + } + _ => { + return Err(Error::UnsupportedExecutionMode(mode_id)); + } + } + + Ok(()) + } + + fn parse_string(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Source, inst.op)?; + inst.expect_at_least(3)?; + let _id = self.next()?; + let (_name, _) = self.next_string(inst.wc - 2)?; + Ok(()) + } + + fn parse_source(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Source, inst.op)?; + for _ in 1..inst.wc { + let _ = self.next()?; + } + Ok(()) + } + + fn parse_source_extension(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Source, inst.op)?; + inst.expect_at_least(2)?; + let (_name, _) = self.next_string(inst.wc - 1)?; + Ok(()) + } + + fn parse_name(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Name, inst.op)?; + inst.expect_at_least(3)?; + let id = self.next()?; + let (name, left) = self.next_string(inst.wc - 2)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + self.future_decor.entry(id).or_default().name = Some(name); + Ok(()) + } + + fn parse_member_name(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Name, inst.op)?; + inst.expect_at_least(4)?; + let id = self.next()?; + let member = self.next()?; + let (name, left) = self.next_string(inst.wc - 3)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + + self.future_member_decor + .entry((id, member)) + .or_default() + .name = Some(name); + Ok(()) + } + + fn parse_module_processed(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Name, inst.op)?; + inst.expect_at_least(2)?; + let (_info, left) = self.next_string(inst.wc - 1)?; + //Note: string is ignored + if left != 0 { + return Err(Error::InvalidOperand); + } + Ok(()) + } + + fn parse_decorate(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Annotation, inst.op)?; + inst.expect_at_least(3)?; + let id = self.next()?; + let mut dec = self.future_decor.remove(&id).unwrap_or_default(); + self.next_decoration(inst, 2, &mut dec)?; + self.future_decor.insert(id, dec); + Ok(()) + } + + fn parse_member_decorate(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Annotation, inst.op)?; + inst.expect_at_least(4)?; + let id = self.next()?; + let member = self.next()?; + + let mut dec = self + .future_member_decor + .remove(&(id, member)) + .unwrap_or_default(); + self.next_decoration(inst, 3, &mut dec)?; + self.future_member_decor.insert((id, member), dec); + Ok(()) + } + + fn parse_type_void(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(2)?; + let id = self.next()?; + self.lookup_void_type = Some(id); + Ok(()) + } + + fn parse_type_bool( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(2)?; + let id = self.next()?; + let inner = crate::TypeInner::Scalar(crate::Scalar::BOOL); + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_int( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let width = self.next()?; + let sign = self.next()?; + let inner = crate::TypeInner::Scalar(crate::Scalar { + kind: match sign { + 0 => crate::ScalarKind::Uint, + 1 => crate::ScalarKind::Sint, + _ => return Err(Error::InvalidSign(sign)), + }, + width: map_width(width)?, + }); + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_float( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let id = self.next()?; + let width = self.next()?; + let inner = crate::TypeInner::Scalar(crate::Scalar::float(map_width(width)?)); + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_vector( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let type_id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?; + let scalar = match module.types[type_lookup.handle].inner { + crate::TypeInner::Scalar(scalar) => scalar, + _ => return Err(Error::InvalidInnerType(type_id)), + }; + let component_count = self.next()?; + let inner = crate::TypeInner::Vector { + size: map_vector_size(component_count)?, + scalar, + }; + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: Some(type_id), + }, + ); + Ok(()) + } + + fn parse_type_matrix( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let vector_type_id = self.next()?; + let num_columns = self.next()?; + let decor = self.future_decor.remove(&id); + + let vector_type_lookup = self.lookup_type.lookup(vector_type_id)?; + let inner = match module.types[vector_type_lookup.handle].inner { + crate::TypeInner::Vector { size, scalar } => crate::TypeInner::Matrix { + columns: map_vector_size(num_columns)?, + rows: size, + scalar, + }, + _ => return Err(Error::InvalidInnerType(vector_type_id)), + }; + + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: decor.and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: Some(vector_type_id), + }, + ); + Ok(()) + } + + fn parse_type_function(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(3)?; + let id = self.next()?; + let return_type_id = self.next()?; + let parameter_type_ids = self.data.by_ref().take(inst.wc as usize - 3).collect(); + self.lookup_function_type.insert( + id, + LookupFunctionType { + parameter_type_ids, + return_type_id, + }, + ); + Ok(()) + } + + fn parse_type_pointer( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let storage_class = self.next()?; + let type_id = self.next()?; + + let decor = self.future_decor.remove(&id); + let base_lookup_ty = self.lookup_type.lookup(type_id)?; + let base_inner = &module.types[base_lookup_ty.handle].inner; + + let space = if let Some(space) = base_inner.pointer_space() { + space + } else if self + .lookup_storage_buffer_types + .contains_key(&base_lookup_ty.handle) + { + crate::AddressSpace::Storage { + access: crate::StorageAccess::default(), + } + } else { + match map_storage_class(storage_class)? { + ExtendedClass::Global(space) => space, + ExtendedClass::Input | ExtendedClass::Output => crate::AddressSpace::Private, + } + }; + + // We don't support pointers to runtime-sized arrays in the `Uniform` + // storage class with the `BufferBlock` decoration. Runtime-sized arrays + // should be in the StorageBuffer class. + if let crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = *base_inner + { + match space { + crate::AddressSpace::Storage { .. } => {} + _ => { + return Err(Error::UnsupportedRuntimeArrayStorageClass); + } + } + } + + // Don't bother with pointer stuff for `Handle` types. + let lookup_ty = if space == crate::AddressSpace::Handle { + base_lookup_ty.clone() + } else { + LookupType { + handle: module.types.insert( + crate::Type { + name: decor.and_then(|dec| dec.name), + inner: crate::TypeInner::Pointer { + base: base_lookup_ty.handle, + space, + }, + }, + self.span_from_with_op(start), + ), + base_id: Some(type_id), + } + }; + self.lookup_type.insert(id, lookup_ty); + Ok(()) + } + + fn parse_type_array( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let type_id = self.next()?; + let length_id = self.next()?; + let length_const = self.lookup_constant.lookup(length_id)?; + + let size = resolve_constant(module.to_ctx(), length_const.handle) + .and_then(NonZeroU32::new) + .ok_or(Error::InvalidArraySize(length_const.handle))?; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + let base = self.lookup_type.lookup(type_id)?.handle; + + self.layouter.update(module.to_ctx()).unwrap(); + + // HACK if the underlying type is an image or a sampler, let's assume + // that we're dealing with a binding-array + // + // Note that it's not a strictly correct assumption, but rather a trade + // off caused by an impedance mismatch between SPIR-V's and Naga's type + // systems - Naga distinguishes between arrays and binding-arrays via + // types (i.e. both kinds of arrays are just different types), while + // SPIR-V distinguishes between them through usage - e.g. given: + // + // ``` + // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f + // %uint_256 = OpConstant %uint 256 + // %image_array = OpTypeArray %image %uint_256 + // ``` + // + // ``` + // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f + // %uint_256 = OpConstant %uint 256 + // %image_array = OpTypeArray %image %uint_256 + // %image_array_ptr = OpTypePointer UniformConstant %image_array + // ``` + // + // ... in the first case, `%image_array` should technically correspond + // to `TypeInner::Array`, while in the second case it should say + // `TypeInner::BindingArray` (kinda, depending on whether `%image_array` + // is ever used as a freestanding type or rather always through the + // pointer-indirection). + // + // Anyway, at the moment we don't support other kinds of image / sampler + // arrays than those binding-based, so this assumption is pretty safe + // for now. + let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } = + module.types[base].inner + { + crate::TypeInner::BindingArray { + base, + size: crate::ArraySize::Constant(size), + } + } else { + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + stride: match decor.array_stride { + Some(stride) => stride.get(), + None => self.layouter[base].to_stride(), + }, + } + }; + + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: decor.name, + inner, + }, + self.span_from_with_op(start), + ), + base_id: Some(type_id), + }, + ); + Ok(()) + } + + fn parse_type_runtime_array( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let id = self.next()?; + let type_id = self.next()?; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + let base = self.lookup_type.lookup(type_id)?.handle; + + self.layouter.update(module.to_ctx()).unwrap(); + + // HACK same case as in `parse_type_array()` + let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } = + module.types[base].inner + { + crate::TypeInner::BindingArray { + base: self.lookup_type.lookup(type_id)?.handle, + size: crate::ArraySize::Dynamic, + } + } else { + crate::TypeInner::Array { + base: self.lookup_type.lookup(type_id)?.handle, + size: crate::ArraySize::Dynamic, + stride: match decor.array_stride { + Some(stride) => stride.get(), + None => self.layouter[base].to_stride(), + }, + } + }; + + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: decor.name, + inner, + }, + self.span_from_with_op(start), + ), + base_id: Some(type_id), + }, + ); + Ok(()) + } + + fn parse_type_struct( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(2)?; + let id = self.next()?; + let parent_decor = self.future_decor.remove(&id); + let is_storage_buffer = parent_decor + .as_ref() + .map_or(false, |decor| decor.storage_buffer); + + self.layouter.update(module.to_ctx()).unwrap(); + + let mut members = Vec::<crate::StructMember>::with_capacity(inst.wc as usize - 2); + let mut member_lookups = Vec::with_capacity(members.capacity()); + let mut storage_access = crate::StorageAccess::empty(); + let mut span = 0; + let mut alignment = Alignment::ONE; + for i in 0..u32::from(inst.wc) - 2 { + let type_id = self.next()?; + let ty = self.lookup_type.lookup(type_id)?.handle; + let decor = self + .future_member_decor + .remove(&(id, i)) + .unwrap_or_default(); + + storage_access |= decor.flags.to_storage_access(); + + member_lookups.push(LookupMember { + type_id, + row_major: decor.matrix_major == Some(Majority::Row), + }); + + let member_alignment = self.layouter[ty].alignment; + span = member_alignment.round_up(span); + alignment = member_alignment.max(alignment); + + let binding = decor.io_binding().ok(); + if let Some(offset) = decor.offset { + span = offset; + } + let offset = span; + + span += self.layouter[ty].size; + + let inner = &module.types[ty].inner; + if let crate::TypeInner::Matrix { + columns, + rows, + scalar, + } = *inner + { + if let Some(stride) = decor.matrix_stride { + let expected_stride = Alignment::from(rows) * scalar.width as u32; + if stride.get() != expected_stride { + return Err(Error::UnsupportedMatrixStride { + stride: stride.get(), + columns: columns as u8, + rows: rows as u8, + width: scalar.width, + }); + } + } + } + + members.push(crate::StructMember { + name: decor.name, + ty, + binding, + offset, + }); + } + + span = alignment.round_up(span); + + let inner = crate::TypeInner::Struct { span, members }; + + let ty_handle = module.types.insert( + crate::Type { + name: parent_decor.and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ); + + if is_storage_buffer { + self.lookup_storage_buffer_types + .insert(ty_handle, storage_access); + } + for (i, member_lookup) in member_lookups.into_iter().enumerate() { + self.lookup_member + .insert((ty_handle, i as u32), member_lookup); + } + self.lookup_type.insert( + id, + LookupType { + handle: ty_handle, + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_image( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(9)?; + + let id = self.next()?; + let sample_type_id = self.next()?; + let dim = self.next()?; + let is_depth = self.next()?; + let is_array = self.next()? != 0; + let is_msaa = self.next()? != 0; + let _is_sampled = self.next()?; + let format = self.next()?; + + let dim = map_image_dim(dim)?; + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + // ensure there is a type for texture coordinate without extra components + module.types.insert( + crate::Type { + name: None, + inner: { + let scalar = crate::Scalar::F32; + match dim.required_coordinate_size() { + None => crate::TypeInner::Scalar(scalar), + Some(size) => crate::TypeInner::Vector { size, scalar }, + } + }, + }, + Default::default(), + ); + + let base_handle = self.lookup_type.lookup(sample_type_id)?.handle; + let kind = module.types[base_handle] + .inner + .scalar_kind() + .ok_or(Error::InvalidImageBaseType(base_handle))?; + + let inner = crate::TypeInner::Image { + class: if is_depth == 1 { + crate::ImageClass::Depth { multi: is_msaa } + } else if format != 0 { + crate::ImageClass::Storage { + format: map_image_format(format)?, + access: crate::StorageAccess::default(), + } + } else { + crate::ImageClass::Sampled { + kind, + multi: is_msaa, + } + }, + dim, + arrayed: is_array, + }; + + let handle = module.types.insert( + crate::Type { + name: decor.name, + inner, + }, + self.span_from_with_op(start), + ); + + self.lookup_type.insert( + id, + LookupType { + handle, + base_id: Some(sample_type_id), + }, + ); + Ok(()) + } + + fn parse_type_sampled_image(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let id = self.next()?; + let image_id = self.next()?; + self.lookup_type.insert( + id, + LookupType { + handle: self.lookup_type.lookup(image_id)?.handle, + base_id: Some(image_id), + }, + ); + Ok(()) + } + + fn parse_type_sampler( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(2)?; + let id = self.next()?; + let decor = self.future_decor.remove(&id).unwrap_or_default(); + let handle = module.types.insert( + crate::Type { + name: decor.name, + inner: crate::TypeInner::Sampler { comparison: false }, + }, + self.span_from_with_op(start), + ); + self.lookup_type.insert( + id, + LookupType { + handle, + base_id: None, + }, + ); + Ok(()) + } + + fn parse_constant( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(4)?; + let type_id = self.next()?; + let id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + let literal = match module.types[ty].inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + width, + }) => { + let low = self.next()?; + match width { + 4 => crate::Literal::U32(low), + _ => return Err(Error::InvalidTypeWidth(width as u32)), + } + } + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width, + }) => { + let low = self.next()?; + match width { + 4 => crate::Literal::I32(low as i32), + 8 => { + inst.expect(5)?; + let high = self.next()?; + crate::Literal::I64((u64::from(high) << 32 | u64::from(low)) as i64) + } + _ => return Err(Error::InvalidTypeWidth(width as u32)), + } + } + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Float, + width, + }) => { + let low = self.next()?; + match width { + 4 => crate::Literal::F32(f32::from_bits(low)), + 8 => { + inst.expect(5)?; + let high = self.next()?; + crate::Literal::F64(f64::from_bits( + (u64::from(high) << 32) | u64::from(low), + )) + } + _ => return Err(Error::InvalidTypeWidth(width as u32)), + } + } + _ => return Err(Error::UnsupportedType(type_lookup.handle)), + }; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let span = self.span_from_with_op(start); + + let init = module + .const_expressions + .append(crate::Expression::Literal(literal), span); + self.lookup_constant.insert( + id, + LookupConstant { + handle: module.constants.append( + crate::Constant { + r#override: decor.specialization(), + name: decor.name, + ty, + init, + }, + span, + ), + type_id, + }, + ); + Ok(()) + } + + fn parse_composite_constant( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(3)?; + let type_id = self.next()?; + let id = self.next()?; + + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + let mut components = Vec::with_capacity(inst.wc as usize - 3); + for _ in 0..components.capacity() { + let start = self.data_offset; + let component_id = self.next()?; + let span = self.span_from_with_op(start); + let constant = self.lookup_constant.lookup(component_id)?; + let expr = module + .const_expressions + .append(crate::Expression::Constant(constant.handle), span); + components.push(expr); + } + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let span = self.span_from_with_op(start); + + let init = module + .const_expressions + .append(crate::Expression::Compose { ty, components }, span); + self.lookup_constant.insert( + id, + LookupConstant { + handle: module.constants.append( + crate::Constant { + r#override: decor.specialization(), + name: decor.name, + ty, + init, + }, + span, + ), + type_id, + }, + ); + Ok(()) + } + + fn parse_null_constant( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let type_id = self.next()?; + let id = self.next()?; + let span = self.span_from_with_op(start); + + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let init = module + .const_expressions + .append(crate::Expression::ZeroValue(ty), span); + let handle = module.constants.append( + crate::Constant { + r#override: decor.specialization(), + name: decor.name, + ty, + init, + }, + span, + ); + self.lookup_constant + .insert(id, LookupConstant { handle, type_id }); + Ok(()) + } + + fn parse_bool_constant( + &mut self, + inst: Instruction, + value: bool, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let type_id = self.next()?; + let id = self.next()?; + let span = self.span_from_with_op(start); + + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let init = module.const_expressions.append( + crate::Expression::Literal(crate::Literal::Bool(value)), + span, + ); + self.lookup_constant.insert( + id, + LookupConstant { + handle: module.constants.append( + crate::Constant { + r#override: decor.specialization(), + name: decor.name, + ty, + init, + }, + span, + ), + type_id, + }, + ); + Ok(()) + } + + fn parse_global_variable( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(4)?; + let type_id = self.next()?; + let id = self.next()?; + let storage_class = self.next()?; + let init = if inst.wc > 4 { + inst.expect(5)?; + let start = self.data_offset; + let init_id = self.next()?; + let span = self.span_from_with_op(start); + let lconst = self.lookup_constant.lookup(init_id)?; + let expr = module + .const_expressions + .append(crate::Expression::Constant(lconst.handle), span); + Some(expr) + } else { + None + }; + let span = self.span_from_with_op(start); + let mut dec = self.future_decor.remove(&id).unwrap_or_default(); + + let original_ty = self.lookup_type.lookup(type_id)?.handle; + let mut ty = original_ty; + + if let crate::TypeInner::Pointer { base, space: _ } = module.types[original_ty].inner { + ty = base; + } + + if let crate::TypeInner::BindingArray { .. } = module.types[original_ty].inner { + // Inside `parse_type_array()` we guess that an array of images or + // samplers must be a binding array, and here we validate that guess + if dec.desc_set.is_none() || dec.desc_index.is_none() { + return Err(Error::NonBindingArrayOfImageOrSamplers); + } + } + + if let crate::TypeInner::Image { + dim, + arrayed, + class: crate::ImageClass::Storage { format, access: _ }, + } = module.types[ty].inner + { + // Storage image types in IR have to contain the access, but not in the SPIR-V. + // The same image type in SPIR-V can be used (and has to be used) for multiple images. + // So we copy the type out and apply the variable access decorations. + let access = dec.flags.to_storage_access(); + + ty = module.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Image { + dim, + arrayed, + class: crate::ImageClass::Storage { format, access }, + }, + }, + Default::default(), + ); + } + + let ext_class = match self.lookup_storage_buffer_types.get(&ty) { + Some(&access) => ExtendedClass::Global(crate::AddressSpace::Storage { access }), + None => map_storage_class(storage_class)?, + }; + + // Fix empty name for gl_PerVertex struct generated by glslang + if let crate::TypeInner::Pointer { .. } = module.types[original_ty].inner { + if ext_class == ExtendedClass::Input || ext_class == ExtendedClass::Output { + if let Some(ref dec_name) = dec.name { + if dec_name.is_empty() { + dec.name = Some("perVertexStruct".to_string()) + } + } + } + } + + let (inner, var) = match ext_class { + ExtendedClass::Global(mut space) => { + if let crate::AddressSpace::Storage { ref mut access } = space { + *access &= dec.flags.to_storage_access(); + } + let var = crate::GlobalVariable { + binding: dec.resource_binding(), + name: dec.name, + space, + ty, + init, + }; + (Variable::Global, var) + } + ExtendedClass::Input => { + let binding = dec.io_binding()?; + let mut unsigned_ty = ty; + if let crate::Binding::BuiltIn(built_in) = binding { + let needs_inner_uint = match built_in { + crate::BuiltIn::BaseInstance + | crate::BuiltIn::BaseVertex + | crate::BuiltIn::InstanceIndex + | crate::BuiltIn::SampleIndex + | crate::BuiltIn::VertexIndex + | crate::BuiltIn::PrimitiveIndex + | crate::BuiltIn::LocalInvocationIndex => { + Some(crate::TypeInner::Scalar(crate::Scalar::U32)) + } + crate::BuiltIn::GlobalInvocationId + | crate::BuiltIn::LocalInvocationId + | crate::BuiltIn::WorkGroupId + | crate::BuiltIn::WorkGroupSize => Some(crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + scalar: crate::Scalar::U32, + }), + _ => None, + }; + if let (Some(inner), Some(crate::ScalarKind::Sint)) = + (needs_inner_uint, module.types[ty].inner.scalar_kind()) + { + unsigned_ty = module + .types + .insert(crate::Type { name: None, inner }, Default::default()); + } + } + + let var = crate::GlobalVariable { + name: dec.name.clone(), + space: crate::AddressSpace::Private, + binding: None, + ty, + init: None, + }; + + let inner = Variable::Input(crate::FunctionArgument { + name: dec.name, + ty: unsigned_ty, + binding: Some(binding), + }); + (inner, var) + } + ExtendedClass::Output => { + // For output interface blocks, this would be a structure. + let binding = dec.io_binding().ok(); + let init = match binding { + Some(crate::Binding::BuiltIn(built_in)) => { + match null::generate_default_built_in( + Some(built_in), + ty, + &mut module.const_expressions, + span, + ) { + Ok(handle) => Some(handle), + Err(e) => { + log::warn!("Failed to initialize output built-in: {}", e); + None + } + } + } + Some(crate::Binding::Location { .. }) => None, + None => match module.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let mut components = Vec::with_capacity(members.len()); + for member in members.iter() { + let built_in = match member.binding { + Some(crate::Binding::BuiltIn(built_in)) => Some(built_in), + _ => None, + }; + let handle = null::generate_default_built_in( + built_in, + member.ty, + &mut module.const_expressions, + span, + )?; + components.push(handle); + } + Some( + module + .const_expressions + .append(crate::Expression::Compose { ty, components }, span), + ) + } + _ => None, + }, + }; + + let var = crate::GlobalVariable { + name: dec.name, + space: crate::AddressSpace::Private, + binding: None, + ty, + init, + }; + let inner = Variable::Output(crate::FunctionResult { ty, binding }); + (inner, var) + } + }; + + let handle = module.global_variables.append(var, span); + + if module.types[ty].inner.can_comparison_sample(module) { + log::debug!("\t\ttracking {:?} for sampling properties", handle); + + self.handle_sampling + .insert(handle, image::SamplingFlags::empty()); + } + + self.lookup_variable.insert( + id, + LookupVariable { + inner, + handle, + type_id, + }, + ); + Ok(()) + } +} + +fn make_index_literal( + ctx: &mut BlockContext, + index: u32, + block: &mut crate::Block, + emitter: &mut crate::proc::Emitter, + index_type: Handle<crate::Type>, + index_type_id: spirv::Word, + span: crate::Span, +) -> Result<Handle<crate::Expression>, Error> { + block.extend(emitter.finish(ctx.expressions)); + + let literal = match ctx.type_arena[index_type].inner.scalar_kind() { + Some(crate::ScalarKind::Uint) => crate::Literal::U32(index), + Some(crate::ScalarKind::Sint) => crate::Literal::I32(index as i32), + _ => return Err(Error::InvalidIndexType(index_type_id)), + }; + let expr = ctx + .expressions + .append(crate::Expression::Literal(literal), span); + + emitter.start(ctx.expressions); + Ok(expr) +} + +fn resolve_constant( + gctx: crate::proc::GlobalCtx, + constant: Handle<crate::Constant>, +) -> Option<u32> { + match gctx.const_expressions[gctx.constants[constant].init] { + crate::Expression::Literal(crate::Literal::U32(id)) => Some(id), + crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32), + _ => None, + } +} + +pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> { + if data.len() % 4 != 0 { + return Err(Error::IncompleteData); + } + + let words = data + .chunks(4) + .map(|c| u32::from_le_bytes(c.try_into().unwrap())); + Frontend::new(words, options).parse() +} + +#[cfg(test)] +mod test { + #[test] + fn parse() { + let bin = vec![ + // Magic number. Version number: 1.0. + 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00, + // Generator number: 0. Bound: 0. + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Reserved word: 0. + 0x00, 0x00, 0x00, 0x00, // OpMemoryModel. Logical. + 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450. + 0x01, 0x00, 0x00, 0x00, + ]; + let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap(); + } +} + +/// Helper function to check if `child` is in the scope of `parent` +fn is_parent(mut child: usize, parent: usize, block_ctx: &BlockContext) -> bool { + loop { + if child == parent { + // The child is in the scope parent + break true; + } else if child == 0 { + // Searched finished at the root the child isn't in the parent's body + break false; + } + + child = block_ctx.bodies[child].parent; + } +} diff --git a/third_party/rust/naga/src/front/spv/null.rs b/third_party/rust/naga/src/front/spv/null.rs new file mode 100644 index 0000000000..42cccca80a --- /dev/null +++ b/third_party/rust/naga/src/front/spv/null.rs @@ -0,0 +1,31 @@ +use super::Error; +use crate::arena::{Arena, Handle}; + +/// Create a default value for an output built-in. +pub fn generate_default_built_in( + built_in: Option<crate::BuiltIn>, + ty: Handle<crate::Type>, + const_expressions: &mut Arena<crate::Expression>, + span: crate::Span, +) -> Result<Handle<crate::Expression>, Error> { + let expr = match built_in { + Some(crate::BuiltIn::Position { .. }) => { + let zero = const_expressions + .append(crate::Expression::Literal(crate::Literal::F32(0.0)), span); + let one = const_expressions + .append(crate::Expression::Literal(crate::Literal::F32(1.0)), span); + crate::Expression::Compose { + ty, + components: vec![zero, zero, zero, one], + } + } + Some(crate::BuiltIn::PointSize) => crate::Expression::Literal(crate::Literal::F32(1.0)), + Some(crate::BuiltIn::FragDepth) => crate::Expression::Literal(crate::Literal::F32(0.0)), + Some(crate::BuiltIn::SampleMask) => { + crate::Expression::Literal(crate::Literal::U32(u32::MAX)) + } + // Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path + _ => crate::Expression::ZeroValue(ty), + }; + Ok(const_expressions.append(expr, span)) +} diff --git a/third_party/rust/naga/src/front/type_gen.rs b/third_party/rust/naga/src/front/type_gen.rs new file mode 100644 index 0000000000..34730c1db5 --- /dev/null +++ b/third_party/rust/naga/src/front/type_gen.rs @@ -0,0 +1,437 @@ +/*! +Type generators. +*/ + +use crate::{arena::Handle, span::Span}; + +impl crate::Module { + /// Populate this module's [`SpecialTypes::ray_desc`] type. + /// + /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of + /// an [`Initialize`] [`RayQuery`] statement. In WGSL, it is a struct type + /// referred to as `RayDesc`. + /// + /// Backends consume values of this type to drive platform APIs, so if you + /// change any its fields, you must update the backends to match. Look for + /// backend code dealing with [`RayQueryFunction::Initialize`]. + /// + /// [`SpecialTypes::ray_desc`]: crate::SpecialTypes::ray_desc + /// [`descriptor`]: crate::RayQueryFunction::Initialize::descriptor + /// [`Initialize`]: crate::RayQueryFunction::Initialize + /// [`RayQuery`]: crate::Statement::RayQuery + /// [`RayQueryFunction::Initialize`]: crate::RayQueryFunction::Initialize + pub fn generate_ray_desc_type(&mut self) -> Handle<crate::Type> { + if let Some(handle) = self.special_types.ray_desc { + return handle; + } + + let ty_flag = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::U32), + }, + Span::UNDEFINED, + ); + let ty_scalar = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::F32), + }, + Span::UNDEFINED, + ); + let ty_vector = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + }, + }, + Span::UNDEFINED, + ); + + let handle = self.types.insert( + crate::Type { + name: Some("RayDesc".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("flags".to_string()), + ty: ty_flag, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("cull_mask".to_string()), + ty: ty_flag, + binding: None, + offset: 4, + }, + crate::StructMember { + name: Some("tmin".to_string()), + ty: ty_scalar, + binding: None, + offset: 8, + }, + crate::StructMember { + name: Some("tmax".to_string()), + ty: ty_scalar, + binding: None, + offset: 12, + }, + crate::StructMember { + name: Some("origin".to_string()), + ty: ty_vector, + binding: None, + offset: 16, + }, + crate::StructMember { + name: Some("dir".to_string()), + ty: ty_vector, + binding: None, + offset: 32, + }, + ], + span: 48, + }, + }, + Span::UNDEFINED, + ); + + self.special_types.ray_desc = Some(handle); + handle + } + + /// Populate this module's [`SpecialTypes::ray_intersection`] type. + /// + /// [`SpecialTypes::ray_intersection`] is the type of a + /// `RayQueryGetIntersection` expression. In WGSL, it is a struct type + /// referred to as `RayIntersection`. + /// + /// Backends construct values of this type based on platform APIs, so if you + /// change any its fields, you must update the backends to match. Look for + /// the backend's handling for [`Expression::RayQueryGetIntersection`]. + /// + /// [`SpecialTypes::ray_intersection`]: crate::SpecialTypes::ray_intersection + /// [`Expression::RayQueryGetIntersection`]: crate::Expression::RayQueryGetIntersection + pub fn generate_ray_intersection_type(&mut self) -> Handle<crate::Type> { + if let Some(handle) = self.special_types.ray_intersection { + return handle; + } + + let ty_flag = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::U32), + }, + Span::UNDEFINED, + ); + let ty_scalar = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::F32), + }, + Span::UNDEFINED, + ); + let ty_barycentrics = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size: crate::VectorSize::Bi, + scalar: crate::Scalar::F32, + }, + }, + Span::UNDEFINED, + ); + let ty_bool = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::BOOL), + }, + Span::UNDEFINED, + ); + let ty_transform = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + }, + }, + Span::UNDEFINED, + ); + + let handle = self.types.insert( + crate::Type { + name: Some("RayIntersection".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("kind".to_string()), + ty: ty_flag, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("t".to_string()), + ty: ty_scalar, + binding: None, + offset: 4, + }, + crate::StructMember { + name: Some("instance_custom_index".to_string()), + ty: ty_flag, + binding: None, + offset: 8, + }, + crate::StructMember { + name: Some("instance_id".to_string()), + ty: ty_flag, + binding: None, + offset: 12, + }, + crate::StructMember { + name: Some("sbt_record_offset".to_string()), + ty: ty_flag, + binding: None, + offset: 16, + }, + crate::StructMember { + name: Some("geometry_index".to_string()), + ty: ty_flag, + binding: None, + offset: 20, + }, + crate::StructMember { + name: Some("primitive_index".to_string()), + ty: ty_flag, + binding: None, + offset: 24, + }, + crate::StructMember { + name: Some("barycentrics".to_string()), + ty: ty_barycentrics, + binding: None, + offset: 28, + }, + crate::StructMember { + name: Some("front_face".to_string()), + ty: ty_bool, + binding: None, + offset: 36, + }, + crate::StructMember { + name: Some("object_to_world".to_string()), + ty: ty_transform, + binding: None, + offset: 48, + }, + crate::StructMember { + name: Some("world_to_object".to_string()), + ty: ty_transform, + binding: None, + offset: 112, + }, + ], + span: 176, + }, + }, + Span::UNDEFINED, + ); + + self.special_types.ray_intersection = Some(handle); + handle + } + + /// Populate this module's [`SpecialTypes::predeclared_types`] type and return the handle. + /// + /// [`SpecialTypes::predeclared_types`]: crate::SpecialTypes::predeclared_types + pub fn generate_predeclared_type( + &mut self, + special_type: crate::PredeclaredType, + ) -> Handle<crate::Type> { + use std::fmt::Write; + + if let Some(value) = self.special_types.predeclared_types.get(&special_type) { + return *value; + } + + let ty = match special_type { + crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => { + let bool_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::BOOL), + }, + Span::UNDEFINED, + ); + let scalar_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(scalar), + }, + Span::UNDEFINED, + ); + + crate::Type { + name: Some(format!( + "__atomic_compare_exchange_result<{:?},{}>", + scalar.kind, scalar.width, + )), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("old_value".to_string()), + ty: scalar_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("exchanged".to_string()), + ty: bool_ty, + binding: None, + offset: 4, + }, + ], + span: 8, + }, + } + } + crate::PredeclaredType::ModfResult { size, width } => { + let float_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::float(width)), + }, + Span::UNDEFINED, + ); + + let (member_ty, second_offset) = if let Some(size) = size { + let vec_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size, + scalar: crate::Scalar::float(width), + }, + }, + Span::UNDEFINED, + ); + (vec_ty, size as u32 * width as u32) + } else { + (float_ty, width as u32) + }; + + let mut type_name = "__modf_result_".to_string(); + if let Some(size) = size { + let _ = write!(type_name, "vec{}_", size as u8); + } + let _ = write!(type_name, "f{}", width * 8); + + crate::Type { + name: Some(type_name), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("fract".to_string()), + ty: member_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("whole".to_string()), + ty: member_ty, + binding: None, + offset: second_offset, + }, + ], + span: second_offset * 2, + }, + } + } + crate::PredeclaredType::FrexpResult { size, width } => { + let float_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::float(width)), + }, + Span::UNDEFINED, + ); + + let int_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width, + }), + }, + Span::UNDEFINED, + ); + + let (fract_member_ty, exp_member_ty, second_offset) = if let Some(size) = size { + let vec_float_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size, + scalar: crate::Scalar::float(width), + }, + }, + Span::UNDEFINED, + ); + let vec_int_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size, + scalar: crate::Scalar { + kind: crate::ScalarKind::Sint, + width, + }, + }, + }, + Span::UNDEFINED, + ); + (vec_float_ty, vec_int_ty, size as u32 * width as u32) + } else { + (float_ty, int_ty, width as u32) + }; + + let mut type_name = "__frexp_result_".to_string(); + if let Some(size) = size { + let _ = write!(type_name, "vec{}_", size as u8); + } + let _ = write!(type_name, "f{}", width * 8); + + crate::Type { + name: Some(type_name), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("fract".to_string()), + ty: fract_member_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("exp".to_string()), + ty: exp_member_ty, + binding: None, + offset: second_offset, + }, + ], + span: second_offset * 2, + }, + } + } + }; + + let handle = self.types.insert(ty, Span::UNDEFINED); + self.special_types + .predeclared_types + .insert(special_type, handle); + handle + } +} diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs new file mode 100644 index 0000000000..07e68f8dd9 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/error.rs @@ -0,0 +1,775 @@ +use crate::front::wgsl::parse::lexer::Token; +use crate::front::wgsl::Scalar; +use crate::proc::{Alignment, ConstantEvaluatorError, ResolveError}; +use crate::{SourceLocation, Span}; +use codespan_reporting::diagnostic::{Diagnostic, Label}; +use codespan_reporting::files::SimpleFile; +use codespan_reporting::term; +use std::borrow::Cow; +use std::ops::Range; +use termcolor::{ColorChoice, NoColor, StandardStream}; +use thiserror::Error; + +#[derive(Clone, Debug)] +pub struct ParseError { + message: String, + labels: Vec<(Span, Cow<'static, str>)>, + notes: Vec<String>, +} + +impl ParseError { + pub fn labels(&self) -> impl ExactSizeIterator<Item = (Span, &str)> + '_ { + self.labels + .iter() + .map(|&(span, ref msg)| (span, msg.as_ref())) + } + + pub fn message(&self) -> &str { + &self.message + } + + fn diagnostic(&self) -> Diagnostic<()> { + let diagnostic = Diagnostic::error() + .with_message(self.message.to_string()) + .with_labels( + self.labels + .iter() + .filter_map(|label| label.0.to_range().map(|range| (label, range))) + .map(|(label, range)| { + Label::primary((), range).with_message(label.1.to_string()) + }) + .collect(), + ) + .with_notes( + self.notes + .iter() + .map(|note| format!("note: {note}")) + .collect(), + ); + diagnostic + } + + /// Emits a summary of the error to standard error stream. + pub fn emit_to_stderr(&self, source: &str) { + self.emit_to_stderr_with_path(source, "wgsl") + } + + /// Emits a summary of the error to standard error stream. + pub fn emit_to_stderr_with_path<P>(&self, source: &str, path: P) + where + P: AsRef<std::path::Path>, + { + let path = path.as_ref().display().to_string(); + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + let writer = StandardStream::stderr(ColorChoice::Auto); + term::emit(&mut writer.lock(), &config, &files, &self.diagnostic()) + .expect("cannot write error"); + } + + /// Emits a summary of the error to a string. + pub fn emit_to_string(&self, source: &str) -> String { + self.emit_to_string_with_path(source, "wgsl") + } + + /// Emits a summary of the error to a string. + pub fn emit_to_string_with_path<P>(&self, source: &str, path: P) -> String + where + P: AsRef<std::path::Path>, + { + let path = path.as_ref().display().to_string(); + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + let mut writer = NoColor::new(Vec::new()); + term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error"); + String::from_utf8(writer.into_inner()).unwrap() + } + + /// Returns a [`SourceLocation`] for the first label in the error message. + pub fn location(&self, source: &str) -> Option<SourceLocation> { + self.labels.get(0).map(|label| label.0.location(source)) + } +} + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for ParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum ExpectedToken<'a> { + Token(Token<'a>), + Identifier, + /// Expected: constant, parenthesized expression, identifier + PrimaryExpression, + /// Expected: assignment, increment/decrement expression + Assignment, + /// Expected: 'case', 'default', '}' + SwitchItem, + /// Expected: ',', ')' + WorkgroupSizeSeparator, + /// Expected: 'struct', 'let', 'var', 'type', ';', 'fn', eof + GlobalItem, + /// Expected a type. + Type, + /// Access of `var`, `let`, `const`. + Variable, + /// Access of a function + Function, +} + +#[derive(Clone, Copy, Debug, Error, PartialEq)] +pub enum NumberError { + #[error("invalid numeric literal format")] + Invalid, + #[error("numeric literal not representable by target type")] + NotRepresentable, + #[error("unimplemented f16 type")] + UnimplementedF16, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum InvalidAssignmentType { + Other, + Swizzle, + ImmutableBinding(Span), +} + +#[derive(Clone, Debug)] +pub enum Error<'a> { + Unexpected(Span, ExpectedToken<'a>), + UnexpectedComponents(Span), + UnexpectedOperationInConstContext(Span), + BadNumber(Span, NumberError), + BadMatrixScalarKind(Span, Scalar), + BadAccessor(Span), + BadTexture(Span), + BadTypeCast { + span: Span, + from_type: String, + to_type: String, + }, + BadTextureSampleType { + span: Span, + scalar: Scalar, + }, + BadIncrDecrReferenceType(Span), + InvalidResolve(ResolveError), + InvalidForInitializer(Span), + /// A break if appeared outside of a continuing block + InvalidBreakIf(Span), + InvalidGatherComponent(Span), + InvalidConstructorComponentType(Span, i32), + InvalidIdentifierUnderscore(Span), + ReservedIdentifierPrefix(Span), + UnknownAddressSpace(Span), + RepeatedAttribute(Span), + UnknownAttribute(Span), + UnknownBuiltin(Span), + UnknownAccess(Span), + UnknownIdent(Span, &'a str), + UnknownScalarType(Span), + UnknownType(Span), + UnknownStorageFormat(Span), + UnknownConservativeDepth(Span), + SizeAttributeTooLow(Span, u32), + AlignAttributeTooLow(Span, Alignment), + NonPowerOfTwoAlignAttribute(Span), + InconsistentBinding(Span), + TypeNotConstructible(Span), + TypeNotInferable(Span), + InitializationTypeMismatch { + name: Span, + expected: String, + got: String, + }, + MissingType(Span), + MissingAttribute(&'static str, Span), + InvalidAtomicPointer(Span), + InvalidAtomicOperandType(Span), + InvalidRayQueryPointer(Span), + Pointer(&'static str, Span), + NotPointer(Span), + NotReference(&'static str, Span), + InvalidAssignment { + span: Span, + ty: InvalidAssignmentType, + }, + ReservedKeyword(Span), + /// Redefinition of an identifier (used for both module-scope and local redefinitions). + Redefinition { + /// Span of the identifier in the previous definition. + previous: Span, + + /// Span of the identifier in the new definition. + current: Span, + }, + /// A declaration refers to itself directly. + RecursiveDeclaration { + /// The location of the name of the declaration. + ident: Span, + + /// The point at which it is used. + usage: Span, + }, + /// A declaration refers to itself indirectly, through one or more other + /// definitions. + CyclicDeclaration { + /// The location of the name of some declaration in the cycle. + ident: Span, + + /// The edges of the cycle of references. + /// + /// Each `(decl, reference)` pair indicates that the declaration whose + /// name is `decl` has an identifier at `reference` whose definition is + /// the next declaration in the cycle. The last pair's `reference` is + /// the same identifier as `ident`, above. + path: Vec<(Span, Span)>, + }, + InvalidSwitchValue { + uint: bool, + span: Span, + }, + CalledEntryPoint(Span), + WrongArgumentCount { + span: Span, + expected: Range<u32>, + found: u32, + }, + FunctionReturnsVoid(Span), + InvalidWorkGroupUniformLoad(Span), + Internal(&'static str), + ExpectedConstExprConcreteIntegerScalar(Span), + ExpectedNonNegative(Span), + ExpectedPositiveArrayLength(Span), + MissingWorkgroupSize(Span), + ConstantEvaluatorError(ConstantEvaluatorError, Span), + AutoConversion { + dest_span: Span, + dest_type: String, + source_span: Span, + source_type: String, + }, + AutoConversionLeafScalar { + dest_span: Span, + dest_scalar: String, + source_span: Span, + source_type: String, + }, + ConcretizationFailed { + expr_span: Span, + expr_type: String, + scalar: String, + inner: ConstantEvaluatorError, + }, +} + +impl<'a> Error<'a> { + pub(crate) fn as_parse_error(&self, source: &'a str) -> ParseError { + match *self { + Error::Unexpected(unexpected_span, expected) => { + let expected_str = match expected { + ExpectedToken::Token(token) => { + match token { + Token::Separator(c) => format!("'{c}'"), + Token::Paren(c) => format!("'{c}'"), + Token::Attribute => "@".to_string(), + Token::Number(_) => "number".to_string(), + Token::Word(s) => s.to_string(), + Token::Operation(c) => format!("operation ('{c}')"), + Token::LogicalOperation(c) => format!("logical operation ('{c}')"), + Token::ShiftOperation(c) => format!("bitshift ('{c}{c}')"), + Token::AssignmentOperation(c) if c=='<' || c=='>' => format!("bitshift ('{c}{c}=')"), + Token::AssignmentOperation(c) => format!("operation ('{c}=')"), + Token::IncrementOperation => "increment operation".to_string(), + Token::DecrementOperation => "decrement operation".to_string(), + Token::Arrow => "->".to_string(), + Token::Unknown(c) => format!("unknown ('{c}')"), + Token::Trivia => "trivia".to_string(), + Token::End => "end".to_string(), + } + } + ExpectedToken::Identifier => "identifier".to_string(), + ExpectedToken::PrimaryExpression => "expression".to_string(), + ExpectedToken::Assignment => "assignment or increment/decrement".to_string(), + ExpectedToken::SwitchItem => "switch item ('case' or 'default') or a closing curly bracket to signify the end of the switch statement ('}')".to_string(), + ExpectedToken::WorkgroupSizeSeparator => "workgroup size separator (',') or a closing parenthesis".to_string(), + ExpectedToken::GlobalItem => "global item ('struct', 'const', 'var', 'alias', ';', 'fn') or the end of the file".to_string(), + ExpectedToken::Type => "type".to_string(), + ExpectedToken::Variable => "variable access".to_string(), + ExpectedToken::Function => "function name".to_string(), + }; + ParseError { + message: format!( + "expected {}, found '{}'", + expected_str, &source[unexpected_span], + ), + labels: vec![(unexpected_span, format!("expected {expected_str}").into())], + notes: vec![], + } + } + Error::UnexpectedComponents(bad_span) => ParseError { + message: "unexpected components".to_string(), + labels: vec![(bad_span, "unexpected components".into())], + notes: vec![], + }, + Error::UnexpectedOperationInConstContext(span) => ParseError { + message: "this operation is not supported in a const context".to_string(), + labels: vec![(span, "operation not supported here".into())], + notes: vec![], + }, + Error::BadNumber(bad_span, ref err) => ParseError { + message: format!("{}: `{}`", err, &source[bad_span],), + labels: vec![(bad_span, err.to_string().into())], + notes: vec![], + }, + Error::BadMatrixScalarKind(span, scalar) => ParseError { + message: format!( + "matrix scalar type must be floating-point, but found `{}`", + scalar.to_wgsl() + ), + labels: vec![(span, "must be floating-point (e.g. `f32`)".into())], + notes: vec![], + }, + Error::BadAccessor(accessor_span) => ParseError { + message: format!("invalid field accessor `{}`", &source[accessor_span],), + labels: vec![(accessor_span, "invalid accessor".into())], + notes: vec![], + }, + Error::UnknownIdent(ident_span, ident) => ParseError { + message: format!("no definition in scope for identifier: '{ident}'"), + labels: vec![(ident_span, "unknown identifier".into())], + notes: vec![], + }, + Error::UnknownScalarType(bad_span) => ParseError { + message: format!("unknown scalar type: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown scalar type".into())], + notes: vec!["Valid scalar types are f32, f64, i32, u32, bool".into()], + }, + Error::BadTextureSampleType { span, scalar } => ParseError { + message: format!( + "texture sample type must be one of f32, i32 or u32, but found {}", + scalar.to_wgsl() + ), + labels: vec![(span, "must be one of f32, i32 or u32".into())], + notes: vec![], + }, + Error::BadIncrDecrReferenceType(span) => ParseError { + message: + "increment/decrement operation requires reference type to be one of i32 or u32" + .to_string(), + labels: vec![(span, "must be a reference type of i32 or u32".into())], + notes: vec![], + }, + Error::BadTexture(bad_span) => ParseError { + message: format!( + "expected an image, but found '{}' which is not an image", + &source[bad_span] + ), + labels: vec![(bad_span, "not an image".into())], + notes: vec![], + }, + Error::BadTypeCast { + span, + ref from_type, + ref to_type, + } => { + let msg = format!("cannot cast a {from_type} to a {to_type}"); + ParseError { + message: msg.clone(), + labels: vec![(span, msg.into())], + notes: vec![], + } + } + Error::InvalidResolve(ref resolve_error) => ParseError { + message: resolve_error.to_string(), + labels: vec![], + notes: vec![], + }, + Error::InvalidForInitializer(bad_span) => ParseError { + message: format!( + "for(;;) initializer is not an assignment or a function call: '{}'", + &source[bad_span] + ), + labels: vec![(bad_span, "not an assignment or function call".into())], + notes: vec![], + }, + Error::InvalidBreakIf(bad_span) => ParseError { + message: "A break if is only allowed in a continuing block".to_string(), + labels: vec![(bad_span, "not in a continuing block".into())], + notes: vec![], + }, + Error::InvalidGatherComponent(bad_span) => ParseError { + message: format!( + "textureGather component '{}' doesn't exist, must be 0, 1, 2, or 3", + &source[bad_span] + ), + labels: vec![(bad_span, "invalid component".into())], + notes: vec![], + }, + Error::InvalidConstructorComponentType(bad_span, component) => ParseError { + message: format!("invalid type for constructor component at index [{component}]"), + labels: vec![(bad_span, "invalid component type".into())], + notes: vec![], + }, + Error::InvalidIdentifierUnderscore(bad_span) => ParseError { + message: "Identifier can't be '_'".to_string(), + labels: vec![(bad_span, "invalid identifier".into())], + notes: vec![ + "Use phony assignment instead ('_ =' notice the absence of 'let' or 'var')" + .to_string(), + ], + }, + Error::ReservedIdentifierPrefix(bad_span) => ParseError { + message: format!( + "Identifier starts with a reserved prefix: '{}'", + &source[bad_span] + ), + labels: vec![(bad_span, "invalid identifier".into())], + notes: vec![], + }, + Error::UnknownAddressSpace(bad_span) => ParseError { + message: format!("unknown address space: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown address space".into())], + notes: vec![], + }, + Error::RepeatedAttribute(bad_span) => ParseError { + message: format!("repeated attribute: '{}'", &source[bad_span]), + labels: vec![(bad_span, "repeated attribute".into())], + notes: vec![], + }, + Error::UnknownAttribute(bad_span) => ParseError { + message: format!("unknown attribute: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown attribute".into())], + notes: vec![], + }, + Error::UnknownBuiltin(bad_span) => ParseError { + message: format!("unknown builtin: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown builtin".into())], + notes: vec![], + }, + Error::UnknownAccess(bad_span) => ParseError { + message: format!("unknown access: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown access".into())], + notes: vec![], + }, + Error::UnknownStorageFormat(bad_span) => ParseError { + message: format!("unknown storage format: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown storage format".into())], + notes: vec![], + }, + Error::UnknownConservativeDepth(bad_span) => ParseError { + message: format!("unknown conservative depth: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown conservative depth".into())], + notes: vec![], + }, + Error::UnknownType(bad_span) => ParseError { + message: format!("unknown type: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown type".into())], + notes: vec![], + }, + Error::SizeAttributeTooLow(bad_span, min_size) => ParseError { + message: format!("struct member size must be at least {min_size}"), + labels: vec![(bad_span, format!("must be at least {min_size}").into())], + notes: vec![], + }, + Error::AlignAttributeTooLow(bad_span, min_align) => ParseError { + message: format!("struct member alignment must be at least {min_align}"), + labels: vec![(bad_span, format!("must be at least {min_align}").into())], + notes: vec![], + }, + Error::NonPowerOfTwoAlignAttribute(bad_span) => ParseError { + message: "struct member alignment must be a power of 2".to_string(), + labels: vec![(bad_span, "must be a power of 2".into())], + notes: vec![], + }, + Error::InconsistentBinding(span) => ParseError { + message: "input/output binding is not consistent".to_string(), + labels: vec![(span, "input/output binding is not consistent".into())], + notes: vec![], + }, + Error::TypeNotConstructible(span) => ParseError { + message: format!("type `{}` is not constructible", &source[span]), + labels: vec![(span, "type is not constructible".into())], + notes: vec![], + }, + Error::TypeNotInferable(span) => ParseError { + message: "type can't be inferred".to_string(), + labels: vec![(span, "type can't be inferred".into())], + notes: vec![], + }, + Error::InitializationTypeMismatch { name, ref expected, ref got } => { + ParseError { + message: format!( + "the type of `{}` is expected to be `{}`, but got `{}`", + &source[name], expected, got, + ), + labels: vec![( + name, + format!("definition of `{}`", &source[name]).into(), + )], + notes: vec![], + } + } + Error::MissingType(name_span) => ParseError { + message: format!("variable `{}` needs a type", &source[name_span]), + labels: vec![( + name_span, + format!("definition of `{}`", &source[name_span]).into(), + )], + notes: vec![], + }, + Error::MissingAttribute(name, name_span) => ParseError { + message: format!( + "variable `{}` needs a '{}' attribute", + &source[name_span], name + ), + labels: vec![( + name_span, + format!("definition of `{}`", &source[name_span]).into(), + )], + notes: vec![], + }, + Error::InvalidAtomicPointer(span) => ParseError { + message: "atomic operation is done on a pointer to a non-atomic".to_string(), + labels: vec![(span, "atomic pointer is invalid".into())], + notes: vec![], + }, + Error::InvalidAtomicOperandType(span) => ParseError { + message: "atomic operand type is inconsistent with the operation".to_string(), + labels: vec![(span, "atomic operand type is invalid".into())], + notes: vec![], + }, + Error::InvalidRayQueryPointer(span) => ParseError { + message: "ray query operation is done on a pointer to a non-ray-query".to_string(), + labels: vec![(span, "ray query pointer is invalid".into())], + notes: vec![], + }, + Error::NotPointer(span) => ParseError { + message: "the operand of the `*` operator must be a pointer".to_string(), + labels: vec![(span, "expression is not a pointer".into())], + notes: vec![], + }, + Error::NotReference(what, span) => ParseError { + message: format!("{what} must be a reference"), + labels: vec![(span, "expression is not a reference".into())], + notes: vec![], + }, + Error::InvalidAssignment { span, ty } => { + let (extra_label, notes) = match ty { + InvalidAssignmentType::Swizzle => ( + None, + vec![ + "WGSL does not support assignments to swizzles".into(), + "consider assigning each component individually".into(), + ], + ), + InvalidAssignmentType::ImmutableBinding(binding_span) => ( + Some((binding_span, "this is an immutable binding".into())), + vec![format!( + "consider declaring '{}' with `var` instead of `let`", + &source[binding_span] + )], + ), + InvalidAssignmentType::Other => (None, vec![]), + }; + + ParseError { + message: "invalid left-hand side of assignment".into(), + labels: std::iter::once((span, "cannot assign to this expression".into())) + .chain(extra_label) + .collect(), + notes, + } + } + Error::Pointer(what, span) => ParseError { + message: format!("{what} must not be a pointer"), + labels: vec![(span, "expression is a pointer".into())], + notes: vec![], + }, + Error::ReservedKeyword(name_span) => ParseError { + message: format!("name `{}` is a reserved keyword", &source[name_span]), + labels: vec![( + name_span, + format!("definition of `{}`", &source[name_span]).into(), + )], + notes: vec![], + }, + Error::Redefinition { previous, current } => ParseError { + message: format!("redefinition of `{}`", &source[current]), + labels: vec![ + ( + current, + format!("redefinition of `{}`", &source[current]).into(), + ), + ( + previous, + format!("previous definition of `{}`", &source[previous]).into(), + ), + ], + notes: vec![], + }, + Error::RecursiveDeclaration { ident, usage } => ParseError { + message: format!("declaration of `{}` is recursive", &source[ident]), + labels: vec![(ident, "".into()), (usage, "uses itself here".into())], + notes: vec![], + }, + Error::CyclicDeclaration { ident, ref path } => ParseError { + message: format!("declaration of `{}` is cyclic", &source[ident]), + labels: path + .iter() + .enumerate() + .flat_map(|(i, &(ident, usage))| { + [ + (ident, "".into()), + ( + usage, + if i == path.len() - 1 { + "ending the cycle".into() + } else { + format!("uses `{}`", &source[ident]).into() + }, + ), + ] + }) + .collect(), + notes: vec![], + }, + Error::InvalidSwitchValue { uint, span } => ParseError { + message: "invalid switch value".to_string(), + labels: vec![( + span, + if uint { + "expected unsigned integer" + } else { + "expected signed integer" + } + .into(), + )], + notes: vec![if uint { + format!("suffix the integer with a `u`: '{}u'", &source[span]) + } else { + let span = span.to_range().unwrap(); + format!( + "remove the `u` suffix: '{}'", + &source[span.start..span.end - 1] + ) + }], + }, + Error::CalledEntryPoint(span) => ParseError { + message: "entry point cannot be called".to_string(), + labels: vec![(span, "entry point cannot be called".into())], + notes: vec![], + }, + Error::WrongArgumentCount { + span, + ref expected, + found, + } => ParseError { + message: format!( + "wrong number of arguments: expected {}, found {}", + if expected.len() < 2 { + format!("{}", expected.start) + } else { + format!("{}..{}", expected.start, expected.end) + }, + found + ), + labels: vec![(span, "wrong number of arguments".into())], + notes: vec![], + }, + Error::FunctionReturnsVoid(span) => ParseError { + message: "function does not return any value".to_string(), + labels: vec![(span, "".into())], + notes: vec![ + "perhaps you meant to call the function in a separate statement?".into(), + ], + }, + Error::InvalidWorkGroupUniformLoad(span) => ParseError { + message: "incorrect type passed to workgroupUniformLoad".into(), + labels: vec![(span, "".into())], + notes: vec!["passed type must be a workgroup pointer".into()], + }, + Error::Internal(message) => ParseError { + message: "internal WGSL front end error".to_string(), + labels: vec![], + notes: vec![message.into()], + }, + Error::ExpectedConstExprConcreteIntegerScalar(span) => ParseError { + message: "must be a const-expression that resolves to a concrete integer scalar (u32 or i32)".to_string(), + labels: vec![(span, "must resolve to u32 or i32".into())], + notes: vec![], + }, + Error::ExpectedNonNegative(span) => ParseError { + message: "must be non-negative (>= 0)".to_string(), + labels: vec![(span, "must be non-negative".into())], + notes: vec![], + }, + Error::ExpectedPositiveArrayLength(span) => ParseError { + message: "array element count must be positive (> 0)".to_string(), + labels: vec![(span, "must be positive".into())], + notes: vec![], + }, + Error::ConstantEvaluatorError(ref e, span) => ParseError { + message: e.to_string(), + labels: vec![(span, "see msg".into())], + notes: vec![], + }, + Error::MissingWorkgroupSize(span) => ParseError { + message: "workgroup size is missing on compute shader entry point".to_string(), + labels: vec![( + span, + "must be paired with a @workgroup_size attribute".into(), + )], + notes: vec![], + }, + Error::AutoConversion { dest_span, ref dest_type, source_span, ref source_type } => ParseError { + message: format!("automatic conversions cannot convert `{source_type}` to `{dest_type}`"), + labels: vec![ + ( + dest_span, + format!("a value of type {dest_type} is required here").into(), + ), + ( + source_span, + format!("this expression has type {source_type}").into(), + ) + ], + notes: vec![], + }, + Error::AutoConversionLeafScalar { dest_span, ref dest_scalar, source_span, ref source_type } => ParseError { + message: format!("automatic conversions cannot convert elements of `{source_type}` to `{dest_scalar}`"), + labels: vec![ + ( + dest_span, + format!("a value with elements of type {dest_scalar} is required here").into(), + ), + ( + source_span, + format!("this expression has type {source_type}").into(), + ) + ], + notes: vec![], + }, + Error::ConcretizationFailed { expr_span, ref expr_type, ref scalar, ref inner } => ParseError { + message: format!("failed to convert expression to a concrete type: {}", inner), + labels: vec![ + ( + expr_span, + format!("this expression has type {}", expr_type).into(), + ) + ], + notes: vec![ + format!("the expression should have been converted to have {} scalar type", scalar), + ] + }, + } + } +} diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs new file mode 100644 index 0000000000..a5524fe8f1 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/index.rs @@ -0,0 +1,193 @@ +use super::Error; +use crate::front::wgsl::parse::ast; +use crate::{FastHashMap, Handle, Span}; + +/// A `GlobalDecl` list in which each definition occurs before all its uses. +pub struct Index<'a> { + dependency_order: Vec<Handle<ast::GlobalDecl<'a>>>, +} + +impl<'a> Index<'a> { + /// Generate an `Index` for the given translation unit. + /// + /// Perform a topological sort on `tu`'s global declarations, placing + /// referents before the definitions that refer to them. + /// + /// Return an error if the graph of references between declarations contains + /// any cycles. + pub fn generate(tu: &ast::TranslationUnit<'a>) -> Result<Self, Error<'a>> { + // Produce a map from global definitions' names to their `Handle<GlobalDecl>`s. + // While doing so, reject conflicting definitions. + let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default()); + for (handle, decl) in tu.decls.iter() { + let ident = decl_ident(decl); + let name = ident.name; + if let Some(old) = globals.insert(name, handle) { + return Err(Error::Redefinition { + previous: decl_ident(&tu.decls[old]).span, + current: ident.span, + }); + } + } + + let len = tu.decls.len(); + let solver = DependencySolver { + globals: &globals, + module: tu, + visited: vec![false; len], + temp_visited: vec![false; len], + path: Vec::new(), + out: Vec::with_capacity(len), + }; + let dependency_order = solver.solve()?; + + Ok(Self { dependency_order }) + } + + /// Iterate over `GlobalDecl`s, visiting each definition before all its uses. + /// + /// Produce handles for all of the `GlobalDecl`s of the `TranslationUnit` + /// passed to `Index::generate`, ordered so that a given declaration is + /// produced before any other declaration that uses it. + pub fn visit_ordered(&self) -> impl Iterator<Item = Handle<ast::GlobalDecl<'a>>> + '_ { + self.dependency_order.iter().copied() + } +} + +/// An edge from a reference to its referent in the current depth-first +/// traversal. +/// +/// This is like `ast::Dependency`, except that we've determined which +/// `GlobalDecl` it refers to. +struct ResolvedDependency<'a> { + /// The referent of some identifier used in the current declaration. + decl: Handle<ast::GlobalDecl<'a>>, + + /// Where that use occurs within the current declaration. + usage: Span, +} + +/// Local state for ordering a `TranslationUnit`'s module-scope declarations. +/// +/// Values of this type are used temporarily by `Index::generate` +/// to perform a depth-first sort on the declarations. +/// Technically, what we want is a topological sort, but a depth-first sort +/// has one key benefit - it's much more efficient in storing +/// the path of each node for error generation. +struct DependencySolver<'source, 'temp> { + /// A map from module-scope definitions' names to their handles. + globals: &'temp FastHashMap<&'source str, Handle<ast::GlobalDecl<'source>>>, + + /// The translation unit whose declarations we're ordering. + module: &'temp ast::TranslationUnit<'source>, + + /// For each handle, whether we have pushed it onto `out` yet. + visited: Vec<bool>, + + /// For each handle, whether it is an predecessor in the current depth-first + /// traversal. This is used to detect cycles in the reference graph. + temp_visited: Vec<bool>, + + /// The current path in our depth-first traversal. Used for generating + /// error messages for non-trivial reference cycles. + path: Vec<ResolvedDependency<'source>>, + + /// The list of declaration handles, with declarations before uses. + out: Vec<Handle<ast::GlobalDecl<'source>>>, +} + +impl<'a> DependencySolver<'a, '_> { + /// Produce the sorted list of declaration handles, and check for cycles. + fn solve(mut self) -> Result<Vec<Handle<ast::GlobalDecl<'a>>>, Error<'a>> { + for (id, _) in self.module.decls.iter() { + if self.visited[id.index()] { + continue; + } + + self.dfs(id)?; + } + + Ok(self.out) + } + + /// Ensure that all declarations used by `id` have been added to the + /// ordering, and then append `id` itself. + fn dfs(&mut self, id: Handle<ast::GlobalDecl<'a>>) -> Result<(), Error<'a>> { + let decl = &self.module.decls[id]; + let id_usize = id.index(); + + self.temp_visited[id_usize] = true; + for dep in decl.dependencies.iter() { + if let Some(&dep_id) = self.globals.get(dep.ident) { + self.path.push(ResolvedDependency { + decl: dep_id, + usage: dep.usage, + }); + let dep_id_usize = dep_id.index(); + + if self.temp_visited[dep_id_usize] { + // Found a cycle. + return if dep_id == id { + // A declaration refers to itself directly. + Err(Error::RecursiveDeclaration { + ident: decl_ident(decl).span, + usage: dep.usage, + }) + } else { + // A declaration refers to itself indirectly, through + // one or more other definitions. Report the entire path + // of references. + let start_at = self + .path + .iter() + .rev() + .enumerate() + .find_map(|(i, dep)| (dep.decl == dep_id).then_some(i)) + .unwrap_or(0); + + Err(Error::CyclicDeclaration { + ident: decl_ident(&self.module.decls[dep_id]).span, + path: self.path[start_at..] + .iter() + .map(|curr_dep| { + let curr_id = curr_dep.decl; + let curr_decl = &self.module.decls[curr_id]; + + (decl_ident(curr_decl).span, curr_dep.usage) + }) + .collect(), + }) + }; + } else if !self.visited[dep_id_usize] { + self.dfs(dep_id)?; + } + + // Remove this edge from the current path. + self.path.pop(); + } + + // Ignore unresolved identifiers; they may be predeclared objects. + } + + // Remove this node from the current path. + self.temp_visited[id_usize] = false; + + // Now everything this declaration uses has been visited, and is already + // present in `out`. That means we we can append this one to the + // ordering, and mark it as visited. + self.out.push(id); + self.visited[id_usize] = true; + + Ok(()) + } +} + +const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> { + match decl.kind { + ast::GlobalDeclKind::Fn(ref f) => f.name, + ast::GlobalDeclKind::Var(ref v) => v.name, + ast::GlobalDeclKind::Const(ref c) => c.name, + ast::GlobalDeclKind::Struct(ref s) => s.name, + ast::GlobalDeclKind::Type(ref t) => t.name, + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lower/construction.rs b/third_party/rust/naga/src/front/wgsl/lower/construction.rs new file mode 100644 index 0000000000..de0d11d227 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lower/construction.rs @@ -0,0 +1,616 @@ +use std::num::NonZeroU32; + +use crate::front::wgsl::parse::ast; +use crate::{Handle, Span}; + +use crate::front::wgsl::error::Error; +use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; + +/// A cooked form of `ast::ConstructorType` that uses Naga types whenever +/// possible. +enum Constructor<T> { + /// A vector construction whose component type is inferred from the + /// argument: `vec3(1.0)`. + PartialVector { size: crate::VectorSize }, + + /// A matrix construction whose component type is inferred from the + /// argument: `mat2x2(1,2,3,4)`. + PartialMatrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + }, + + /// An array whose component type and size are inferred from the arguments: + /// `array(3,4,5)`. + PartialArray, + + /// A known Naga type. + /// + /// When we match on this type, we need to see the `TypeInner` here, but at + /// the point that we build this value we'll still need mutable access to + /// the module later. To avoid borrowing from the module, the type parameter + /// `T` is `Handle<Type>` initially. Then we use `borrow_inner` to produce a + /// version holding a tuple `(Handle<Type>, &TypeInner)`. + Type(T), +} + +impl Constructor<Handle<crate::Type>> { + /// Return an equivalent `Constructor` value that includes borrowed + /// `TypeInner` values alongside any type handles. + /// + /// The returned form is more convenient to match on, since the patterns + /// can actually see what the handle refers to. + fn borrow_inner( + self, + module: &crate::Module, + ) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> { + match self { + Constructor::PartialVector { size } => Constructor::PartialVector { size }, + Constructor::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } + } + Constructor::PartialArray => Constructor::PartialArray, + Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)), + } + } +} + +impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> { + fn to_error_string(&self, ctx: &ExpressionContext) -> String { + match *self { + Self::PartialVector { size } => { + format!("vec{}<?>", size as u32,) + } + Self::PartialMatrix { columns, rows } => { + format!("mat{}x{}<?>", columns as u32, rows as u32,) + } + Self::PartialArray => "array<?, ?>".to_string(), + Self::Type((handle, _inner)) => handle.to_wgsl(&ctx.module.to_ctx()), + } + } +} + +enum Components<'a> { + None, + One { + component: Handle<crate::Expression>, + span: Span, + ty_inner: &'a crate::TypeInner, + }, + Many { + components: Vec<Handle<crate::Expression>>, + spans: Vec<Span>, + }, +} + +impl Components<'_> { + fn into_components_vec(self) -> Vec<Handle<crate::Expression>> { + match self { + Self::None => vec![], + Self::One { component, .. } => vec![component], + Self::Many { components, .. } => components, + } + } +} + +impl<'source, 'temp> Lowerer<'source, 'temp> { + /// Generate Naga IR for a type constructor expression. + /// + /// The `constructor` value represents the head of the constructor + /// expression, which is at least a hint of which type is being built; if + /// it's one of the `Partial` variants, we need to consider the argument + /// types as well. + /// + /// This is used for [`Construct`] expressions, but also for [`Call`] + /// expressions, once we've determined that the "callable" (in WGSL spec + /// terms) is actually a type. + /// + /// [`Construct`]: ast::Expression::Construct + /// [`Call`]: ast::Expression::Call + pub fn construct( + &mut self, + span: Span, + constructor: &ast::ConstructorType<'source>, + ty_span: Span, + components: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + use crate::proc::TypeResolution as Tr; + + let constructor_h = self.constructor(constructor, ctx)?; + + let components = match *components { + [] => Components::None, + [component] => { + let span = ctx.ast_expressions.get_span(component); + let component = self.expression_for_abstract(component, ctx)?; + let ty_inner = super::resolve_inner!(ctx, component); + + Components::One { + component, + span, + ty_inner, + } + } + ref ast_components @ [_, _, ..] => { + let components = ast_components + .iter() + .map(|&expr| self.expression_for_abstract(expr, ctx)) + .collect::<Result<_, _>>()?; + let spans = ast_components + .iter() + .map(|&expr| ctx.ast_expressions.get_span(expr)) + .collect(); + + for &component in &components { + ctx.grow_types(component)?; + } + + Components::Many { components, spans } + } + }; + + // Even though we computed `constructor` above, wait until now to borrow + // a reference to the `TypeInner`, so that the component-handling code + // above can have mutable access to the type arena. + let constructor = constructor_h.borrow_inner(ctx.module); + + let expr; + match (components, constructor) { + // Empty constructor + (Components::None, dst_ty) => match dst_ty { + Constructor::Type((result_ty, _)) => { + return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span) + } + Constructor::PartialVector { .. } + | Constructor::PartialMatrix { .. } + | Constructor::PartialArray => { + // We have no arguments from which to infer the result type, so + // partial constructors aren't acceptable here. + return Err(Error::TypeNotInferable(ty_span)); + } + }, + + // Scalar constructor & conversion (scalar -> scalar) + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Scalar { .. }, + .. + }, + Constructor::Type((_, &crate::TypeInner::Scalar(scalar))), + ) => { + expr = crate::Expression::As { + expr: component, + kind: scalar.kind, + convert: Some(scalar.width), + }; + } + + // Vector conversion (vector -> vector) + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, + .. + }, + Constructor::Type(( + _, + &crate::TypeInner::Vector { + size: dst_size, + scalar: dst_scalar, + }, + )), + ) if dst_size == src_size => { + expr = crate::Expression::As { + expr: component, + kind: dst_scalar.kind, + convert: Some(dst_scalar.width), + }; + } + + // Vector conversion (vector -> vector) - partial + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, + .. + }, + Constructor::PartialVector { size: dst_size }, + ) if dst_size == src_size => { + // This is a trivial conversion: the sizes match, and a Partial + // constructor doesn't specify a scalar type, so nothing can + // possibly happen. + return Ok(component); + } + + // Matrix conversion (matrix -> matrix) + ( + Components::One { + component, + ty_inner: + &crate::TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + Constructor::Type(( + _, + &crate::TypeInner::Matrix { + columns: dst_columns, + rows: dst_rows, + scalar: dst_scalar, + }, + )), + ) if dst_columns == src_columns && dst_rows == src_rows => { + expr = crate::Expression::As { + expr: component, + kind: dst_scalar.kind, + convert: Some(dst_scalar.width), + }; + } + + // Matrix conversion (matrix -> matrix) - partial + ( + Components::One { + component, + ty_inner: + &crate::TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + Constructor::PartialMatrix { + columns: dst_columns, + rows: dst_rows, + }, + ) if dst_columns == src_columns && dst_rows == src_rows => { + // This is a trivial conversion: the sizes match, and a Partial + // constructor doesn't specify a scalar type, so nothing can + // possibly happen. + return Ok(component); + } + + // Vector constructor (splat) - infer type + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Scalar { .. }, + .. + }, + Constructor::PartialVector { size }, + ) => { + expr = crate::Expression::Splat { + size, + value: component, + }; + } + + // Vector constructor (splat) + ( + Components::One { + mut component, + ty_inner: &crate::TypeInner::Scalar(_), + .. + }, + Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })), + ) => { + ctx.convert_slice_to_common_leaf_scalar( + std::slice::from_mut(&mut component), + scalar, + )?; + expr = crate::Expression::Splat { + size, + value: component, + }; + } + + // Vector constructor (by elements), partial + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialVector { size }, + ) => { + let consensus_scalar = + ctx.automatic_conversion_consensus(&components) + .map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + let inner = consensus_scalar.to_inner_vector(size); + let ty = ctx.ensure_type_exists(inner); + expr = crate::Expression::Compose { ty, components }; + } + + // Vector constructor (by elements), full type given + ( + Components::Many { mut components, .. }, + Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })), + ) => { + ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?; + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by elements), partial + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialMatrix { columns, rows }, + ) if components.len() == columns as usize * rows as usize => { + let consensus_scalar = + ctx.automatic_conversion_consensus(&components) + .map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + // We actually only accept floating-point elements. + let consensus_scalar = consensus_scalar + .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT) + .unwrap_or(consensus_scalar); + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows)); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.append_expression( + crate::Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect::<Result<Vec<_>, _>>()?; + + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by elements), type given + ( + Components::Many { mut components, .. }, + Constructor::Type(( + _, + &crate::TypeInner::Matrix { + columns, + rows, + scalar, + }, + )), + ) if components.len() == columns as usize * rows as usize => { + let element = Tr::Value(crate::TypeInner::Scalar(scalar)); + ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?; + let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows)); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.append_expression( + crate::Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect::<Result<Vec<_>, _>>()?; + + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by columns), partial + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialMatrix { columns, rows }, + ) => { + let consensus_scalar = + ctx.automatic_conversion_consensus(&components) + .map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by columns), type given + ( + Components::Many { mut components, .. }, + Constructor::Type(( + ty, + &crate::TypeInner::Matrix { + columns: _, + rows, + scalar, + }, + )), + ) => { + let component_ty = crate::TypeInner::Vector { size: rows, scalar }; + ctx.try_automatic_conversions_slice( + &mut components, + &Tr::Value(component_ty), + ty_span, + )?; + expr = crate::Expression::Compose { ty, components }; + } + + // Array constructor - infer type + (components, Constructor::PartialArray) => { + let mut components = components.into_components_vec(); + if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(&components) { + // Note that this will *not* necessarily convert all the + // components to the same type! The `automatic_conversion_consensus` + // method only considers the parameters' leaf scalar + // types; the parameters themselves could be any mix of + // vectors, matrices, and scalars. + // + // But *if* it is possible for this array construction + // expression to be well-typed at all, then all the + // parameters must have the same type constructors (vec, + // matrix, scalar) applied to their leaf scalars, so + // reconciling their scalars is always the right thing to + // do. And if this array construction is not well-typed, + // these conversions will not make it so, and we can let + // validation catch the error. + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + } else { + // There's no consensus scalar. Emit the `Compose` + // expression anyway, and let validation catch the problem. + } + + let base = ctx.register_type(components[0])?; + + let inner = crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant( + NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(), + ), + stride: { + self.layouter.update(ctx.module.to_ctx()).unwrap(); + self.layouter[base].to_stride() + }, + }; + let ty = ctx.ensure_type_exists(inner); + + expr = crate::Expression::Compose { ty, components }; + } + + // Array constructor, explicit type + (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => { + let mut components = components.into_components_vec(); + ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?; + expr = crate::Expression::Compose { ty, components }; + } + + // Struct constructor + ( + components, + Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })), + ) => { + let mut components = components.into_components_vec(); + let struct_ty_span = ctx.module.types.get_span(ty); + + // Make a vector of the members' type handles in advance, to + // avoid borrowing `members` from `ctx` while we generate + // new code. + let members: Vec<Handle<crate::Type>> = members.iter().map(|m| m.ty).collect(); + + for (component, &ty) in components.iter_mut().zip(&members) { + *component = + ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?; + } + expr = crate::Expression::Compose { ty, components }; + } + + // ERRORS + + // Bad conversion (type cast) + (Components::One { span, ty_inner, .. }, constructor) => { + let from_type = ty_inner.to_wgsl(&ctx.module.to_ctx()); + return Err(Error::BadTypeCast { + span, + from_type, + to_type: constructor.to_error_string(ctx), + }); + } + + // Too many parameters for scalar constructor + ( + Components::Many { spans, .. }, + Constructor::Type((_, &crate::TypeInner::Scalar { .. })), + ) => { + let span = spans[1].until(spans.last().unwrap()); + return Err(Error::UnexpectedComponents(span)); + } + + // Other types can't be constructed + _ => return Err(Error::TypeNotConstructible(ty_span)), + } + + let expr = ctx.append_expression(expr, span)?; + Ok(expr) + } + + /// Build a [`Constructor`] for a WGSL construction expression. + /// + /// If `constructor` conveys enough information to determine which Naga [`Type`] + /// we're actually building (i.e., it's not a partial constructor), then + /// ensure the `Type` exists in [`ctx.module`], and return + /// [`Constructor::Type`]. + /// + /// Otherwise, return the [`Constructor`] partial variant corresponding to + /// `constructor`. + /// + /// [`Type`]: crate::Type + /// [`ctx.module`]: ExpressionContext::module + fn constructor<'out>( + &mut self, + constructor: &ast::ConstructorType<'source>, + ctx: &mut ExpressionContext<'source, '_, 'out>, + ) -> Result<Constructor<Handle<crate::Type>>, Error<'source>> { + let handle = match *constructor { + ast::ConstructorType::Scalar(scalar) => { + let ty = ctx.ensure_type_exists(scalar.to_inner_scalar()); + Constructor::Type(ty) + } + ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size }, + ast::ConstructorType::Vector { size, scalar } => { + let ty = ctx.ensure_type_exists(scalar.to_inner_vector(size)); + Constructor::Type(ty) + } + ast::ConstructorType::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } + } + ast::ConstructorType::Matrix { + rows, + columns, + width, + } => { + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: crate::Scalar::float(width), + }); + Constructor::Type(ty) + } + ast::ConstructorType::PartialArray => Constructor::PartialArray, + ast::ConstructorType::Array { base, size } => { + let base = self.resolve_ast_type(base, &mut ctx.as_global())?; + let size = self.array_size(size, &mut ctx.as_global())?; + + self.layouter.update(ctx.module.to_ctx()).unwrap(); + let stride = self.layouter[base].to_stride(); + + let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride }); + Constructor::Type(ty) + } + ast::ConstructorType::Type(ty) => Constructor::Type(ty), + }; + + Ok(handle) + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lower/conversion.rs b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs new file mode 100644 index 0000000000..2a2690f096 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs @@ -0,0 +1,503 @@ +//! WGSL's automatic conversions for abstract types. + +use crate::{Handle, Span}; + +impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> { + /// Try to use WGSL's automatic conversions to convert `expr` to `goal_ty`. + /// + /// If no conversions are necessary, return `expr` unchanged. + /// + /// If automatic conversions cannot convert `expr` to `goal_ty`, return an + /// [`AutoConversion`] error. + /// + /// Although the Load Rule is one of the automatic conversions, this + /// function assumes it has already been applied if appropriate, as + /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. + /// + /// [`AutoConversion`]: super::Error::AutoConversion + pub fn try_automatic_conversions( + &mut self, + expr: Handle<crate::Expression>, + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result<Handle<crate::Expression>, super::Error<'source>> { + let expr_span = self.get_expression_span(expr); + // Keep the TypeResolution so we can get type names for + // structs in error messages. + let expr_resolution = super::resolve!(self, expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + let goal_inner = goal_ty.inner_with(types); + + // If `expr` already has the requested type, we're done. + if expr_inner.equivalent(goal_inner, types) { + return Ok(expr); + } + + let (_expr_scalar, goal_scalar) = + match expr_inner.automatically_converts_to(goal_inner, types) { + Some(scalars) => scalars, + None => { + let gctx = &self.module.to_ctx(); + let source_type = expr_resolution.to_wgsl(gctx); + let dest_type = goal_ty.to_wgsl(gctx); + + return Err(super::Error::AutoConversion { + dest_span: goal_span, + dest_type, + source_span: expr_span, + source_type, + }); + } + }; + + self.convert_leaf_scalar(expr, expr_span, goal_scalar) + } + + /// Try to convert `expr`'s leaf scalar to `goal` using automatic conversions. + /// + /// If no conversions are necessary, return `expr` unchanged. + /// + /// If automatic conversions cannot convert `expr` to `goal_scalar`, return + /// an [`AutoConversionLeafScalar`] error. + /// + /// Although the Load Rule is one of the automatic conversions, this + /// function assumes it has already been applied if appropriate, as + /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. + /// + /// [`AutoConversionLeafScalar`]: super::Error::AutoConversionLeafScalar + pub fn try_automatic_conversion_for_leaf_scalar( + &mut self, + expr: Handle<crate::Expression>, + goal_scalar: crate::Scalar, + goal_span: Span, + ) -> Result<Handle<crate::Expression>, super::Error<'source>> { + let expr_span = self.get_expression_span(expr); + let expr_resolution = super::resolve!(self, expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + + let make_error = || { + let gctx = &self.module.to_ctx(); + let source_type = expr_resolution.to_wgsl(gctx); + super::Error::AutoConversionLeafScalar { + dest_span: goal_span, + dest_scalar: goal_scalar.to_wgsl(), + source_span: expr_span, + source_type, + } + }; + + let expr_scalar = match expr_inner.scalar() { + Some(scalar) => scalar, + None => return Err(make_error()), + }; + + if expr_scalar == goal_scalar { + return Ok(expr); + } + + if !expr_scalar.automatically_converts_to(goal_scalar) { + return Err(make_error()); + } + + assert!(expr_scalar.is_abstract()); + + self.convert_leaf_scalar(expr, expr_span, goal_scalar) + } + + fn convert_leaf_scalar( + &mut self, + expr: Handle<crate::Expression>, + expr_span: Span, + goal_scalar: crate::Scalar, + ) -> Result<Handle<crate::Expression>, super::Error<'source>> { + let expr_inner = super::resolve_inner!(self, expr); + if let crate::TypeInner::Array { .. } = *expr_inner { + self.as_const_evaluator() + .cast_array(expr, goal_scalar, expr_span) + .map_err(|err| super::Error::ConstantEvaluatorError(err, expr_span)) + } else { + let cast = crate::Expression::As { + expr, + kind: goal_scalar.kind, + convert: Some(goal_scalar.width), + }; + self.append_expression(cast, expr_span) + } + } + + /// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions. + pub fn try_automatic_conversions_slice( + &mut self, + exprs: &mut [Handle<crate::Expression>], + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + *expr = self.try_automatic_conversions(*expr, goal_ty, goal_span)?; + } + + Ok(()) + } + + /// Apply WGSL's automatic conversions to a vector constructor's arguments. + /// + /// When calling a vector constructor like `vec3<f32>(...)`, the parameters + /// can be a mix of scalars and vectors, with the latter being spread out to + /// contribute each of their components as a component of the new value. + /// When the element type is explicit, as with `<f32>` in the example above, + /// WGSL's automatic conversions should convert abstract scalar and vector + /// parameters to the constructor's required scalar type. + pub fn try_automatic_conversions_for_vector( + &mut self, + exprs: &mut [Handle<crate::Expression>], + goal_scalar: crate::Scalar, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + use crate::proc::TypeResolution as Tr; + use crate::TypeInner as Ti; + let goal_scalar_res = Tr::Value(Ti::Scalar(goal_scalar)); + + for (i, expr) in exprs.iter_mut().enumerate() { + // Keep the TypeResolution so we can get full type names + // in error messages. + let expr_resolution = super::resolve!(self, *expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + + match *expr_inner { + Ti::Scalar(_) => { + *expr = self.try_automatic_conversions(*expr, &goal_scalar_res, goal_span)?; + } + Ti::Vector { size, scalar: _ } => { + let goal_vector_res = Tr::Value(Ti::Vector { + size, + scalar: goal_scalar, + }); + *expr = self.try_automatic_conversions(*expr, &goal_vector_res, goal_span)?; + } + _ => { + let span = self.get_expression_span(*expr); + return Err(super::Error::InvalidConstructorComponentType( + span, i as i32, + )); + } + } + } + + Ok(()) + } + + /// Convert `expr` to the leaf scalar type `scalar`. + pub fn convert_to_leaf_scalar( + &mut self, + expr: &mut Handle<crate::Expression>, + goal: crate::Scalar, + ) -> Result<(), super::Error<'source>> { + let inner = super::resolve_inner!(self, *expr); + // Do nothing if `inner` doesn't even have leaf scalars; + // it's a type error that validation will catch. + if inner.scalar() != Some(goal) { + let cast = crate::Expression::As { + expr: *expr, + kind: goal.kind, + convert: Some(goal.width), + }; + let expr_span = self.get_expression_span(*expr); + *expr = self.append_expression(cast, expr_span)?; + } + + Ok(()) + } + + /// Convert all expressions in `exprs` to a common scalar type. + /// + /// Note that the caller is responsible for making sure these + /// conversions are actually justified. This function simply + /// generates `As` expressions, regardless of whether they are + /// permitted WGSL automatic conversions. Callers intending to + /// implement automatic conversions need to determine for + /// themselves whether the casts we we generate are justified, + /// perhaps by calling `TypeInner::automatically_converts_to` or + /// `Scalar::automatic_conversion_combine`. + pub fn convert_slice_to_common_leaf_scalar( + &mut self, + exprs: &mut [Handle<crate::Expression>], + goal: crate::Scalar, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + self.convert_to_leaf_scalar(expr, goal)?; + } + + Ok(()) + } + + /// Return an expression for the concretized value of `expr`. + /// + /// If `expr` is already concrete, return it unchanged. + pub fn concretize( + &mut self, + mut expr: Handle<crate::Expression>, + ) -> Result<Handle<crate::Expression>, super::Error<'source>> { + let inner = super::resolve_inner!(self, expr); + if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) { + let concretized = scalar.concretize(); + if concretized != scalar { + assert!(scalar.is_abstract()); + let expr_span = self.get_expression_span(expr); + expr = self + .as_const_evaluator() + .cast_array(expr, concretized, expr_span) + .map_err(|err| { + // A `TypeResolution` includes the type's full name, if + // it has one. Also, avoid holding the borrow of `inner` + // across the call to `cast_array`. + let expr_type = &self.typifier()[expr]; + super::Error::ConcretizationFailed { + expr_span, + expr_type: expr_type.to_wgsl(&self.module.to_ctx()), + scalar: concretized.to_wgsl(), + inner: err, + } + })?; + } + } + + Ok(expr) + } + + /// Find the consensus scalar of `components` under WGSL's automatic + /// conversions. + /// + /// If `components` can all be converted to any common scalar via + /// WGSL's automatic conversions, return the best such scalar. + /// + /// The `components` slice must not be empty. All elements' types must + /// have been resolved. + /// + /// If `components` are definitely not acceptable as arguments to such + /// constructors, return `Err(i)`, where `i` is the index in + /// `components` of some problematic argument. + /// + /// This function doesn't fully type-check the arguments - it only + /// considers their leaf scalar types. This means it may return `Ok` + /// even when the Naga validator will reject the resulting + /// construction expression later. + pub fn automatic_conversion_consensus<'handle, I>( + &self, + components: I, + ) -> Result<crate::Scalar, usize> + where + I: IntoIterator<Item = &'handle Handle<crate::Expression>>, + I::IntoIter: Clone, // for debugging + { + let types = &self.module.types; + let mut inners = components + .into_iter() + .map(|&c| self.typifier()[c].inner_with(types)); + log::debug!( + "wgsl automatic_conversion_consensus: {:?}", + inners + .clone() + .map(|inner| inner.to_wgsl(&self.module.to_ctx())) + .collect::<Vec<String>>() + ); + let mut best = inners.next().unwrap().scalar().ok_or(0_usize)?; + for (inner, i) in inners.zip(1..) { + let scalar = inner.scalar().ok_or(i)?; + match best.automatic_conversion_combine(scalar) { + Some(new_best) => { + best = new_best; + } + None => return Err(i), + } + } + + log::debug!(" consensus: {:?}", best.to_wgsl()); + Ok(best) + } +} + +impl crate::TypeInner { + /// Determine whether `self` automatically converts to `goal`. + /// + /// If WGSL's automatic conversions (excluding the Load Rule) will + /// convert `self` to `goal`, then return a pair `(from, to)`, + /// where `from` and `to` are the scalar types of the leaf values + /// of `self` and `goal`. + /// + /// This function assumes that `self` and `goal` are different + /// types. Callers should first check whether any conversion is + /// needed at all. + /// + /// If the automatic conversions cannot convert `self` to `goal`, + /// return `None`. + fn automatically_converts_to( + &self, + goal: &Self, + types: &crate::UniqueArena<crate::Type>, + ) -> Option<(crate::Scalar, crate::Scalar)> { + use crate::ScalarKind as Sk; + use crate::TypeInner as Ti; + + // Automatic conversions only change the scalar type of a value's leaves + // (e.g., `vec4<AbstractFloat>` to `vec4<f32>`), never the type + // constructors applied to those scalar types (e.g., never scalar to + // `vec4`, or `vec2` to `vec3`). So first we check that the type + // constructors match, extracting the leaf scalar types in the process. + let expr_scalar; + let goal_scalar; + match (self, goal) { + (&Ti::Scalar(expr), &Ti::Scalar(goal)) => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Vector { + size: expr_size, + scalar: expr, + }, + &Ti::Vector { + size: goal_size, + scalar: goal, + }, + ) if expr_size == goal_size => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Matrix { + rows: expr_rows, + columns: expr_columns, + scalar: expr, + }, + &Ti::Matrix { + rows: goal_rows, + columns: goal_columns, + scalar: goal, + }, + ) if expr_rows == goal_rows && expr_columns == goal_columns => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Array { + base: expr_base, + size: expr_size, + stride: _, + }, + &Ti::Array { + base: goal_base, + size: goal_size, + stride: _, + }, + ) if expr_size == goal_size => { + return types[expr_base] + .inner + .automatically_converts_to(&types[goal_base].inner, types); + } + _ => return None, + } + + match (expr_scalar.kind, goal_scalar.kind) { + (Sk::AbstractFloat, Sk::Float) => {} + (Sk::AbstractInt, Sk::Sint | Sk::Uint | Sk::AbstractFloat | Sk::Float) => {} + _ => return None, + } + + log::trace!(" okay: expr {expr_scalar:?}, goal {goal_scalar:?}"); + Some((expr_scalar, goal_scalar)) + } + + fn automatically_convertible_scalar( + &self, + types: &crate::UniqueArena<crate::Type>, + ) -> Option<crate::Scalar> { + use crate::TypeInner as Ti; + match *self { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { + Some(scalar) + } + Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), + Ti::Atomic(_) + | Ti::Pointer { .. } + | Ti::ValuePointer { .. } + | Ti::Struct { .. } + | Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => None, + } + } +} + +impl crate::Scalar { + /// Find the common type of `self` and `other` under WGSL's + /// automatic conversions. + /// + /// If there are any scalars to which WGSL's automatic conversions + /// will convert both `self` and `other`, return the best such + /// scalar. Otherwise, return `None`. + pub const fn automatic_conversion_combine(self, other: Self) -> Option<crate::Scalar> { + use crate::ScalarKind as Sk; + + match (self.kind, other.kind) { + // When the kinds match... + (Sk::AbstractFloat, Sk::AbstractFloat) + | (Sk::AbstractInt, Sk::AbstractInt) + | (Sk::Sint, Sk::Sint) + | (Sk::Uint, Sk::Uint) + | (Sk::Float, Sk::Float) + | (Sk::Bool, Sk::Bool) => { + if self.width == other.width { + // ... either no conversion is necessary ... + Some(self) + } else { + // ... or no conversion is possible. + // We never convert concrete to concrete, and + // abstract types should have only one size. + None + } + } + + // AbstractInt converts to AbstractFloat. + (Sk::AbstractFloat, Sk::AbstractInt) => Some(self), + (Sk::AbstractInt, Sk::AbstractFloat) => Some(other), + + // AbstractFloat converts to Float. + (Sk::AbstractFloat, Sk::Float) => Some(other), + (Sk::Float, Sk::AbstractFloat) => Some(self), + + // AbstractInt converts to concrete integer or float. + (Sk::AbstractInt, Sk::Uint | Sk::Sint | Sk::Float) => Some(other), + (Sk::Uint | Sk::Sint | Sk::Float, Sk::AbstractInt) => Some(self), + + // AbstractFloat can't be reconciled with concrete integer types. + (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => { + None + } + + // Nothing can be reconciled with `bool`. + (Sk::Bool, _) | (_, Sk::Bool) => None, + + // Different concrete types cannot be reconciled. + (Sk::Sint | Sk::Uint | Sk::Float, Sk::Sint | Sk::Uint | Sk::Float) => None, + } + } + + /// Return `true` if automatic conversions will covert `self` to `goal`. + pub fn automatically_converts_to(self, goal: Self) -> bool { + self.automatic_conversion_combine(goal) == Some(goal) + } + + const fn concretize(self) -> Self { + use crate::ScalarKind as Sk; + match self.kind { + Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self, + Sk::AbstractInt => Self::I32, + Sk::AbstractFloat => Self::F32, + } + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs new file mode 100644 index 0000000000..ba9b49e135 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs @@ -0,0 +1,2760 @@ +use std::num::NonZeroU32; + +use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType}; +use crate::front::wgsl::index::Index; +use crate::front::wgsl::parse::number::Number; +use crate::front::wgsl::parse::{ast, conv}; +use crate::front::Typifier; +use crate::proc::{ + ensure_block_returns, Alignment, ConstantEvaluator, Emitter, Layouter, ResolveContext, +}; +use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span}; + +mod construction; +mod conversion; + +/// Resolves the inner type of a given expression. +/// +/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`]. +/// +/// Returns a &[`crate::TypeInner`]. +/// +/// Ideally, we would simply have a function that takes a `&mut ExpressionContext` +/// and returns a `&TypeResolution`. Unfortunately, this leads the borrow checker +/// to conclude that the mutable borrow lasts for as long as we are using the +/// `&TypeResolution`, so we can't use the `ExpressionContext` for anything else - +/// like, say, resolving another operand's type. Using a macro that expands to +/// two separate calls, only the first of which needs a `&mut`, +/// lets the borrow checker see that the mutable borrow is over. +macro_rules! resolve_inner { + ($ctx:ident, $expr:expr) => {{ + $ctx.grow_types($expr)?; + $ctx.typifier()[$expr].inner_with(&$ctx.module.types) + }}; +} +pub(super) use resolve_inner; + +/// Resolves the inner types of two given expressions. +/// +/// Expects a &mut [`ExpressionContext`] and two [`Handle<Expression>`]s. +/// +/// Returns a tuple containing two &[`crate::TypeInner`]. +/// +/// See the documentation of [`resolve_inner!`] for why this macro is necessary. +macro_rules! resolve_inner_binary { + ($ctx:ident, $left:expr, $right:expr) => {{ + $ctx.grow_types($left)?; + $ctx.grow_types($right)?; + ( + $ctx.typifier()[$left].inner_with(&$ctx.module.types), + $ctx.typifier()[$right].inner_with(&$ctx.module.types), + ) + }}; +} + +/// Resolves the type of a given expression. +/// +/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`]. +/// +/// Returns a &[`TypeResolution`]. +/// +/// See the documentation of [`resolve_inner!`] for why this macro is necessary. +/// +/// [`TypeResolution`]: crate::proc::TypeResolution +macro_rules! resolve { + ($ctx:ident, $expr:expr) => {{ + $ctx.grow_types($expr)?; + &$ctx.typifier()[$expr] + }}; +} +pub(super) use resolve; + +/// State for constructing a `crate::Module`. +pub struct GlobalContext<'source, 'temp, 'out> { + /// The `TranslationUnit`'s expressions arena. + ast_expressions: &'temp Arena<ast::Expression<'source>>, + + /// The `TranslationUnit`'s types arena. + types: &'temp Arena<ast::Type<'source>>, + + // Naga IR values. + /// The map from the names of module-scope declarations to the Naga IR + /// `Handle`s we have built for them, owned by `Lowerer::lower`. + globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, + + /// The module we're constructing. + module: &'out mut crate::Module, + + const_typifier: &'temp mut Typifier, +} + +impl<'source> GlobalContext<'source, '_, '_> { + fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + expr_type: ExpressionContextType::Constant, + } + } + + fn ensure_type_exists( + &mut self, + name: Option<String>, + inner: crate::TypeInner, + ) -> Handle<crate::Type> { + self.module + .types + .insert(crate::Type { inner, name }, Span::UNDEFINED) + } +} + +/// State for lowering a statement within a function. +pub struct StatementContext<'source, 'temp, 'out> { + // WGSL AST values. + /// A reference to [`TranslationUnit::expressions`] for the translation unit + /// we're lowering. + /// + /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions + ast_expressions: &'temp Arena<ast::Expression<'source>>, + + /// A reference to [`TranslationUnit::types`] for the translation unit + /// we're lowering. + /// + /// [`TranslationUnit::types`]: ast::TranslationUnit::types + types: &'temp Arena<ast::Type<'source>>, + + // Naga IR values. + /// The map from the names of module-scope declarations to the Naga IR + /// `Handle`s we have built for them, owned by `Lowerer::lower`. + globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, + + /// A map from each `ast::Local` handle to the Naga expression + /// we've built for it: + /// + /// - WGSL function arguments become Naga [`FunctionArgument`] expressions. + /// + /// - WGSL `var` declarations become Naga [`LocalVariable`] expressions. + /// + /// - WGSL `let` declararations become arbitrary Naga expressions. + /// + /// This always borrows the `local_table` local variable in + /// [`Lowerer::function`]. + /// + /// [`LocalVariable`]: crate::Expression::LocalVariable + /// [`FunctionArgument`]: crate::Expression::FunctionArgument + local_table: &'temp mut FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>, + + const_typifier: &'temp mut Typifier, + typifier: &'temp mut Typifier, + function: &'out mut crate::Function, + /// Stores the names of expressions that are assigned in `let` statement + /// Also stores the spans of the names, for use in errors. + named_expressions: &'out mut FastIndexMap<Handle<crate::Expression>, (String, Span)>, + module: &'out mut crate::Module, + + /// Which `Expression`s in `self.naga_expressions` are const expressions, in + /// the WGSL sense. + /// + /// According to the WGSL spec, a const expression must not refer to any + /// `let` declarations, even if those declarations' initializers are + /// themselves const expressions. So this tracker is not simply concerned + /// with the form of the expressions; it is also tracking whether WGSL says + /// we should consider them to be const. See the use of `force_non_const` in + /// the code for lowering `let` bindings. + expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, +} + +impl<'a, 'temp> StatementContext<'a, 'temp, '_> { + fn as_expression<'t>( + &'t mut self, + block: &'t mut crate::Block, + emitter: &'t mut Emitter, + ) -> ExpressionContext<'a, 't, '_> + where + 'temp: 't, + { + ExpressionContext { + globals: self.globals, + types: self.types, + ast_expressions: self.ast_expressions, + const_typifier: self.const_typifier, + module: self.module, + expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { + local_table: self.local_table, + function: self.function, + block, + emitter, + typifier: self.typifier, + expression_constness: self.expression_constness, + }), + } + } + + fn as_global(&mut self) -> GlobalContext<'a, '_, '_> { + GlobalContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + } + } + + fn invalid_assignment_type(&self, expr: Handle<crate::Expression>) -> InvalidAssignmentType { + if let Some(&(_, span)) = self.named_expressions.get(&expr) { + InvalidAssignmentType::ImmutableBinding(span) + } else { + match self.function.expressions[expr] { + crate::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle, + crate::Expression::Access { base, .. } => self.invalid_assignment_type(base), + crate::Expression::AccessIndex { base, .. } => self.invalid_assignment_type(base), + _ => InvalidAssignmentType::Other, + } + } + } +} + +pub struct RuntimeExpressionContext<'temp, 'out> { + /// A map from [`ast::Local`] handles to the Naga expressions we've built for them. + /// + /// This is always [`StatementContext::local_table`] for the + /// enclosing statement; see that documentation for details. + local_table: &'temp FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>, + + function: &'out mut crate::Function, + block: &'temp mut crate::Block, + emitter: &'temp mut Emitter, + typifier: &'temp mut Typifier, + + /// Which `Expression`s in `self.naga_expressions` are const expressions, in + /// the WGSL sense. + /// + /// See [`StatementContext::expression_constness`] for details. + expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, +} + +/// The type of Naga IR expression we are lowering an [`ast::Expression`] to. +pub enum ExpressionContextType<'temp, 'out> { + /// We are lowering to an arbitrary runtime expression, to be + /// included in a function's body. + /// + /// The given [`RuntimeExpressionContext`] holds information about local + /// variables, arguments, and other definitions available only to runtime + /// expressions, not constant or override expressions. + Runtime(RuntimeExpressionContext<'temp, 'out>), + + /// We are lowering to a constant expression, to be included in the module's + /// constant expression arena. + /// + /// Everything constant expressions are allowed to refer to is + /// available in the [`ExpressionContext`], so this variant + /// carries no further information. + Constant, +} + +/// State for lowering an [`ast::Expression`] to Naga IR. +/// +/// [`ExpressionContext`]s come in two kinds, distinguished by +/// the value of the [`expr_type`] field: +/// +/// - A [`Runtime`] context contributes [`naga::Expression`]s to a [`naga::Function`]'s +/// runtime expression arena. +/// +/// - A [`Constant`] context contributes [`naga::Expression`]s to a [`naga::Module`]'s +/// constant expression arena. +/// +/// [`ExpressionContext`]s are constructed in restricted ways: +/// +/// - To get a [`Runtime`] [`ExpressionContext`], call +/// [`StatementContext::as_expression`]. +/// +/// - To get a [`Constant`] [`ExpressionContext`], call +/// [`GlobalContext::as_const`]. +/// +/// - You can demote a [`Runtime`] context to a [`Constant`] context +/// by calling [`as_const`], but there's no way to go in the other +/// direction, producing a runtime context from a constant one. This +/// is because runtime expressions can refer to constant +/// expressions, via [`Expression::Constant`], but constant +/// expressions can't refer to a function's expressions. +/// +/// Not to be confused with `wgsl::parse::ExpressionContext`, which is +/// for parsing the `ast::Expression` in the first place. +/// +/// [`expr_type`]: ExpressionContext::expr_type +/// [`Runtime`]: ExpressionContextType::Runtime +/// [`naga::Expression`]: crate::Expression +/// [`naga::Function`]: crate::Function +/// [`Constant`]: ExpressionContextType::Constant +/// [`naga::Module`]: crate::Module +/// [`as_const`]: ExpressionContext::as_const +/// [`Expression::Constant`]: crate::Expression::Constant +pub struct ExpressionContext<'source, 'temp, 'out> { + // WGSL AST values. + ast_expressions: &'temp Arena<ast::Expression<'source>>, + types: &'temp Arena<ast::Type<'source>>, + + // Naga IR values. + /// The map from the names of module-scope declarations to the Naga IR + /// `Handle`s we have built for them, owned by `Lowerer::lower`. + globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, + + /// The IR [`Module`] we're constructing. + /// + /// [`Module`]: crate::Module + module: &'out mut crate::Module, + + /// Type judgments for [`module::const_expressions`]. + /// + /// [`module::const_expressions`]: crate::Module::const_expressions + const_typifier: &'temp mut Typifier, + + /// Whether we are lowering a constant expression or a general + /// runtime expression, and the data needed in each case. + expr_type: ExpressionContextType<'temp, 'out>, +} + +impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { + fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + globals: self.globals, + types: self.types, + ast_expressions: self.ast_expressions, + const_typifier: self.const_typifier, + module: self.module, + expr_type: ExpressionContextType::Constant, + } + } + + fn as_global(&mut self) -> GlobalContext<'source, '_, '_> { + GlobalContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + } + } + + fn as_const_evaluator(&mut self) -> ConstantEvaluator { + match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function( + self.module, + &mut rctx.function.expressions, + rctx.expression_constness, + rctx.emitter, + rctx.block, + ), + ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module), + } + } + + fn append_expression( + &mut self, + expr: crate::Expression, + span: Span, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let mut eval = self.as_const_evaluator(); + match eval.try_eval_and_append(&expr, span) { + Ok(expr) => Ok(expr), + + // `expr` is not a constant expression. This is fine as + // long as we're not building `Module::const_expressions`. + Err(err) => match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => { + Ok(rctx.function.expressions.append(expr, span)) + } + ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)), + }, + } + } + + fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => { + if !ctx.expression_constness.is_const(handle) { + return None; + } + + self.module + .to_ctx() + .eval_expr_to_u32_from(handle, &ctx.function.expressions) + .ok() + } + ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(), + } + } + + fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), + ExpressionContextType::Constant => self.module.const_expressions.get_span(handle), + } + } + + fn typifier(&self) -> &Typifier { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => ctx.typifier, + ExpressionContextType::Constant => self.const_typifier, + } + } + + fn runtime_expression_ctx( + &mut self, + span: Span, + ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> { + match self.expr_type { + ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx), + ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), + } + } + + fn gather_component( + &mut self, + expr: Handle<crate::Expression>, + component_span: Span, + gather_span: Span, + ) -> Result<crate::SwizzleComponent, Error<'source>> { + match self.expr_type { + ExpressionContextType::Runtime(ref rctx) => { + if !rctx.expression_constness.is_const(expr) { + return Err(Error::ExpectedConstExprConcreteIntegerScalar( + component_span, + )); + } + + let index = self + .module + .to_ctx() + .eval_expr_to_u32_from(expr, &rctx.function.expressions) + .map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(component_span) + } + crate::proc::U32EvalError::Negative => { + Error::ExpectedNonNegative(component_span) + } + })?; + crate::SwizzleComponent::XYZW + .get(index as usize) + .copied() + .ok_or(Error::InvalidGatherComponent(component_span)) + } + // This means a `gather` operation appeared in a constant expression. + // This error refers to the `gather` itself, not its "component" argument. + ExpressionContextType::Constant => { + Err(Error::UnexpectedOperationInConstContext(gather_span)) + } + } + } + + /// Determine the type of `handle`, and add it to the module's arena. + /// + /// If you just need a `TypeInner` for `handle`'s type, use the + /// [`resolve_inner!`] macro instead. This function + /// should only be used when the type of `handle` needs to appear + /// in the module's final `Arena<Type>`, for example, if you're + /// creating a [`LocalVariable`] whose type is inferred from its + /// initializer. + /// + /// [`LocalVariable`]: crate::LocalVariable + fn register_type( + &mut self, + handle: Handle<crate::Expression>, + ) -> Result<Handle<crate::Type>, Error<'source>> { + self.grow_types(handle)?; + // This is equivalent to calling ExpressionContext::typifier(), + // except that this lets the borrow checker see that it's okay + // to also borrow self.module.types mutably below. + let typifier = match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => ctx.typifier, + ExpressionContextType::Constant => &*self.const_typifier, + }; + Ok(typifier.register_type(handle, &mut self.module.types)) + } + + /// Resolve the types of all expressions up through `handle`. + /// + /// Ensure that [`self.typifier`] has a [`TypeResolution`] for + /// every expression in [`self.function.expressions`]. + /// + /// This does not add types to any arena. The [`Typifier`] + /// documentation explains the steps we take to avoid filling + /// arenas with intermediate types. + /// + /// This function takes `&mut self`, so it can't conveniently + /// return a shared reference to the resulting `TypeResolution`: + /// the shared reference would extend the mutable borrow, and you + /// wouldn't be able to use `self` for anything else. Instead, you + /// should use [`register_type`] or one of [`resolve!`], + /// [`resolve_inner!`] or [`resolve_inner_binary!`]. + /// + /// [`self.typifier`]: ExpressionContext::typifier + /// [`TypeResolution`]: crate::proc::TypeResolution + /// [`register_type`]: Self::register_type + /// [`Typifier`]: Typifier + fn grow_types( + &mut self, + handle: Handle<crate::Expression>, + ) -> Result<&mut Self, Error<'source>> { + let empty_arena = Arena::new(); + let resolve_ctx; + let typifier; + let expressions; + match self.expr_type { + ExpressionContextType::Runtime(ref mut ctx) => { + resolve_ctx = ResolveContext::with_locals( + self.module, + &ctx.function.local_variables, + &ctx.function.arguments, + ); + typifier = &mut *ctx.typifier; + expressions = &ctx.function.expressions; + } + ExpressionContextType::Constant => { + resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); + typifier = self.const_typifier; + expressions = &self.module.const_expressions; + } + }; + typifier + .grow(handle, expressions, &resolve_ctx) + .map_err(Error::InvalidResolve)?; + + Ok(self) + } + + fn image_data( + &mut self, + image: Handle<crate::Expression>, + span: Span, + ) -> Result<(crate::ImageClass, bool), Error<'source>> { + match *resolve_inner!(self, image) { + crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)), + _ => Err(Error::BadTexture(span)), + } + } + + fn prepare_args<'b>( + &mut self, + args: &'b [Handle<ast::Expression<'source>>], + min_args: u32, + span: Span, + ) -> ArgumentContext<'b, 'source> { + ArgumentContext { + args: args.iter(), + min_args, + args_used: 0, + total_args: args.len() as u32, + span, + } + } + + /// Insert splats, if needed by the non-'*' operations. + /// + /// See the "Binary arithmetic expressions with mixed scalar and vector operands" + /// table in the WebGPU Shading Language specification for relevant operators. + /// + /// Multiply is not handled here as backends are expected to handle vec*scalar + /// operations, so inserting splats into the IR increases size needlessly. + fn binary_op_splat( + &mut self, + op: crate::BinaryOperator, + left: &mut Handle<crate::Expression>, + right: &mut Handle<crate::Expression>, + ) -> Result<(), Error<'source>> { + if matches!( + op, + crate::BinaryOperator::Add + | crate::BinaryOperator::Subtract + | crate::BinaryOperator::Divide + | crate::BinaryOperator::Modulo + ) { + match resolve_inner_binary!(self, *left, *right) { + (&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => { + *right = self.append_expression( + crate::Expression::Splat { + size, + value: *right, + }, + self.get_expression_span(*right), + )?; + } + (&crate::TypeInner::Scalar { .. }, &crate::TypeInner::Vector { size, .. }) => { + *left = self.append_expression( + crate::Expression::Splat { size, value: *left }, + self.get_expression_span(*left), + )?; + } + _ => {} + } + } + + Ok(()) + } + + /// Add a single expression to the expression table that is not covered by `self.emitter`. + /// + /// This is useful for `CallResult` and `AtomicResult` expressions, which should not be covered by + /// `Emit` statements. + fn interrupt_emitter( + &mut self, + expression: crate::Expression, + span: Span, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => { + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + } + ExpressionContextType::Constant => {} + } + let result = self.append_expression(expression, span); + match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => { + rctx.emitter.start(&rctx.function.expressions); + } + ExpressionContextType::Constant => {} + } + result + } + + /// Apply the WGSL Load Rule to `expr`. + /// + /// If `expr` is has type `ref<SC, T, A>`, perform a load to produce a value of type + /// `T`. Otherwise, return `expr` unchanged. + fn apply_load_rule( + &mut self, + expr: Typed<Handle<crate::Expression>>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + match expr { + Typed::Reference(pointer) => { + let load = crate::Expression::Load { pointer }; + let span = self.get_expression_span(pointer); + self.append_expression(load, span) + } + Typed::Plain(handle) => Ok(handle), + } + } + + fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle<crate::Type> { + self.as_global().ensure_type_exists(None, inner) + } +} + +struct ArgumentContext<'ctx, 'source> { + args: std::slice::Iter<'ctx, Handle<ast::Expression<'source>>>, + min_args: u32, + args_used: u32, + total_args: u32, + span: Span, +} + +impl<'source> ArgumentContext<'_, 'source> { + pub fn finish(self) -> Result<(), Error<'source>> { + if self.args.len() == 0 { + Ok(()) + } else { + Err(Error::WrongArgumentCount { + found: self.total_args, + expected: self.min_args..self.args_used + 1, + span: self.span, + }) + } + } + + pub fn next(&mut self) -> Result<Handle<ast::Expression<'source>>, Error<'source>> { + match self.args.next().copied() { + Some(arg) => { + self.args_used += 1; + Ok(arg) + } + None => Err(Error::WrongArgumentCount { + found: self.total_args, + expected: self.min_args..self.args_used + 1, + span: self.span, + }), + } + } +} + +/// WGSL type annotations on expressions, types, values, etc. +/// +/// Naga and WGSL types are very close, but Naga lacks WGSL's `ref` types, which +/// we need to know to apply the Load Rule. This enum carries some WGSL or Naga +/// datum along with enough information to determine its corresponding WGSL +/// type. +/// +/// The `T` type parameter can be any expression-like thing: +/// +/// - `Typed<Handle<crate::Type>>` can represent a full WGSL type. For example, +/// given some Naga `Pointer` type `ptr`, a WGSL reference type is a +/// `Typed::Reference(ptr)` whereas a WGSL pointer type is a +/// `Typed::Plain(ptr)`. +/// +/// - `Typed<crate::Expression>` or `Typed<Handle<crate::Expression>>` can +/// represent references similarly. +/// +/// Use the `map` and `try_map` methods to convert from one expression +/// representation to another. +/// +/// [`Expression`]: crate::Expression +#[derive(Debug, Copy, Clone)] +enum Typed<T> { + /// A WGSL reference. + Reference(T), + + /// A WGSL plain type. + Plain(T), +} + +impl<T> Typed<T> { + fn map<U>(self, mut f: impl FnMut(T) -> U) -> Typed<U> { + match self { + Self::Reference(v) => Typed::Reference(f(v)), + Self::Plain(v) => Typed::Plain(f(v)), + } + } + + fn try_map<U, E>(self, mut f: impl FnMut(T) -> Result<U, E>) -> Result<Typed<U>, E> { + Ok(match self { + Self::Reference(expr) => Typed::Reference(f(expr)?), + Self::Plain(expr) => Typed::Plain(f(expr)?), + }) + } +} + +/// A single vector component or swizzle. +/// +/// This represents the things that can appear after the `.` in a vector access +/// expression: either a single component name, or a series of them, +/// representing a swizzle. +enum Components { + Single(u32), + Swizzle { + size: crate::VectorSize, + pattern: [crate::SwizzleComponent; 4], + }, +} + +impl Components { + const fn letter_component(letter: char) -> Option<crate::SwizzleComponent> { + use crate::SwizzleComponent as Sc; + match letter { + 'x' | 'r' => Some(Sc::X), + 'y' | 'g' => Some(Sc::Y), + 'z' | 'b' => Some(Sc::Z), + 'w' | 'a' => Some(Sc::W), + _ => None, + } + } + + fn single_component(name: &str, name_span: Span) -> Result<u32, Error> { + let ch = name.chars().next().ok_or(Error::BadAccessor(name_span))?; + match Self::letter_component(ch) { + Some(sc) => Ok(sc as u32), + None => Err(Error::BadAccessor(name_span)), + } + } + + /// Construct a `Components` value from a 'member' name, like `"wzy"` or `"x"`. + /// + /// Use `name_span` for reporting errors in parsing the component string. + fn new(name: &str, name_span: Span) -> Result<Self, Error> { + let size = match name.len() { + 1 => return Ok(Components::Single(Self::single_component(name, name_span)?)), + 2 => crate::VectorSize::Bi, + 3 => crate::VectorSize::Tri, + 4 => crate::VectorSize::Quad, + _ => return Err(Error::BadAccessor(name_span)), + }; + + let mut pattern = [crate::SwizzleComponent::X; 4]; + for (comp, ch) in pattern.iter_mut().zip(name.chars()) { + *comp = Self::letter_component(ch).ok_or(Error::BadAccessor(name_span))?; + } + + Ok(Components::Swizzle { size, pattern }) + } +} + +/// An `ast::GlobalDecl` for which we have built the Naga IR equivalent. +enum LoweredGlobalDecl { + Function(Handle<crate::Function>), + Var(Handle<crate::GlobalVariable>), + Const(Handle<crate::Constant>), + Type(Handle<crate::Type>), + EntryPoint, +} + +enum Texture { + Gather, + GatherCompare, + + Sample, + SampleBias, + SampleCompare, + SampleCompareLevel, + SampleGrad, + SampleLevel, + // SampleBaseClampToEdge, +} + +impl Texture { + pub fn map(word: &str) -> Option<Self> { + Some(match word { + "textureGather" => Self::Gather, + "textureGatherCompare" => Self::GatherCompare, + + "textureSample" => Self::Sample, + "textureSampleBias" => Self::SampleBias, + "textureSampleCompare" => Self::SampleCompare, + "textureSampleCompareLevel" => Self::SampleCompareLevel, + "textureSampleGrad" => Self::SampleGrad, + "textureSampleLevel" => Self::SampleLevel, + // "textureSampleBaseClampToEdge" => Some(Self::SampleBaseClampToEdge), + _ => return None, + }) + } + + pub const fn min_argument_count(&self) -> u32 { + match *self { + Self::Gather => 3, + Self::GatherCompare => 4, + + Self::Sample => 3, + Self::SampleBias => 5, + Self::SampleCompare => 5, + Self::SampleCompareLevel => 5, + Self::SampleGrad => 6, + Self::SampleLevel => 5, + // Self::SampleBaseClampToEdge => 3, + } + } +} + +pub struct Lowerer<'source, 'temp> { + index: &'temp Index<'source>, + layouter: Layouter, +} + +impl<'source, 'temp> Lowerer<'source, 'temp> { + pub fn new(index: &'temp Index<'source>) -> Self { + Self { + index, + layouter: Layouter::default(), + } + } + + pub fn lower( + &mut self, + tu: &'temp ast::TranslationUnit<'source>, + ) -> Result<crate::Module, Error<'source>> { + let mut module = crate::Module::default(); + + let mut ctx = GlobalContext { + ast_expressions: &tu.expressions, + globals: &mut FastHashMap::default(), + types: &tu.types, + module: &mut module, + const_typifier: &mut Typifier::new(), + }; + + for decl_handle in self.index.visit_ordered() { + let span = tu.decls.get_span(decl_handle); + let decl = &tu.decls[decl_handle]; + + match decl.kind { + ast::GlobalDeclKind::Fn(ref f) => { + let lowered_decl = self.function(f, span, &mut ctx)?; + ctx.globals.insert(f.name.name, lowered_decl); + } + ast::GlobalDeclKind::Var(ref v) => { + let ty = self.resolve_ast_type(v.ty, &mut ctx)?; + + let init; + if let Some(init_ast) = v.init { + let mut ectx = ctx.as_const(); + let lowered = self.expression_for_abstract(init_ast, &mut ectx)?; + let ty_res = crate::proc::TypeResolution::Handle(ty); + let converted = ectx + .try_automatic_conversions(lowered, &ty_res, v.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: v.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + init = Some(converted); + } else { + init = None; + } + + let binding = if let Some(ref binding) = v.binding { + Some(crate::ResourceBinding { + group: self.const_u32(binding.group, &mut ctx.as_const())?.0, + binding: self.const_u32(binding.binding, &mut ctx.as_const())?.0, + }) + } else { + None + }; + + let handle = ctx.module.global_variables.append( + crate::GlobalVariable { + name: Some(v.name.name.to_string()), + space: v.space, + binding, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(v.name.name, LoweredGlobalDecl::Var(handle)); + } + ast::GlobalDeclKind::Const(ref c) => { + let mut ectx = ctx.as_const(); + let mut init = self.expression_for_abstract(c.init, &mut ectx)?; + + let ty; + if let Some(explicit_ty) = c.ty { + let explicit_ty = + self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?; + let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + init = ectx + .try_automatic_conversions(init, &explicit_ty_res, c.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: c.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + ty = explicit_ty; + } else { + init = ectx.concretize(init)?; + ty = ectx.register_type(init)?; + } + + let handle = ctx.module.constants.append( + crate::Constant { + name: Some(c.name.name.to_string()), + r#override: crate::Override::None, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(c.name.name, LoweredGlobalDecl::Const(handle)); + } + ast::GlobalDeclKind::Struct(ref s) => { + let handle = self.r#struct(s, span, &mut ctx)?; + ctx.globals + .insert(s.name.name, LoweredGlobalDecl::Type(handle)); + } + ast::GlobalDeclKind::Type(ref alias) => { + let ty = self.resolve_named_ast_type( + alias.ty, + Some(alias.name.name.to_string()), + &mut ctx, + )?; + ctx.globals + .insert(alias.name.name, LoweredGlobalDecl::Type(ty)); + } + } + } + + // Constant evaluation may leave abstract-typed literals and + // compositions in expression arenas, so we need to compact the module + // to remove unused expressions and types. + crate::compact::compact(&mut module); + + Ok(module) + } + + fn function( + &mut self, + f: &ast::Function<'source>, + span: Span, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<LoweredGlobalDecl, Error<'source>> { + let mut local_table = FastHashMap::default(); + let mut expressions = Arena::new(); + let mut named_expressions = FastIndexMap::default(); + + let arguments = f + .arguments + .iter() + .enumerate() + .map(|(i, arg)| { + let ty = self.resolve_ast_type(arg.ty, ctx)?; + let expr = expressions + .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); + local_table.insert(arg.handle, Typed::Plain(expr)); + named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span)); + + Ok(crate::FunctionArgument { + name: Some(arg.name.name.to_string()), + ty, + binding: self.binding(&arg.binding, ty, ctx)?, + }) + }) + .collect::<Result<Vec<_>, _>>()?; + + let result = f + .result + .as_ref() + .map(|res| { + let ty = self.resolve_ast_type(res.ty, ctx)?; + Ok(crate::FunctionResult { + ty, + binding: self.binding(&res.binding, ty, ctx)?, + }) + }) + .transpose()?; + + let mut function = crate::Function { + name: Some(f.name.name.to_string()), + arguments, + result, + local_variables: Arena::new(), + expressions, + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::default(), + }; + + let mut typifier = Typifier::default(); + let mut stmt_ctx = StatementContext { + local_table: &mut local_table, + globals: ctx.globals, + ast_expressions: ctx.ast_expressions, + const_typifier: ctx.const_typifier, + typifier: &mut typifier, + function: &mut function, + named_expressions: &mut named_expressions, + types: ctx.types, + module: ctx.module, + expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), + }; + let mut body = self.block(&f.body, false, &mut stmt_ctx)?; + ensure_block_returns(&mut body); + + function.body = body; + function.named_expressions = named_expressions + .into_iter() + .map(|(key, (name, _))| (key, name)) + .collect(); + + if let Some(ref entry) = f.entry_point { + let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size { + // TODO: replace with try_map once stabilized + let mut workgroup_size_out = [1; 3]; + for (i, size) in workgroup_size.into_iter().enumerate() { + if let Some(size_expr) = size { + workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0; + } + } + workgroup_size_out + } else { + [0; 3] + }; + + ctx.module.entry_points.push(crate::EntryPoint { + name: f.name.name.to_string(), + stage: entry.stage, + early_depth_test: entry.early_depth_test, + workgroup_size, + function, + }); + Ok(LoweredGlobalDecl::EntryPoint) + } else { + let handle = ctx.module.functions.append(function, span); + Ok(LoweredGlobalDecl::Function(handle)) + } + } + + fn block( + &mut self, + b: &ast::Block<'source>, + is_inside_loop: bool, + ctx: &mut StatementContext<'source, '_, '_>, + ) -> Result<crate::Block, Error<'source>> { + let mut block = crate::Block::default(); + + for stmt in b.stmts.iter() { + self.statement(stmt, &mut block, is_inside_loop, ctx)?; + } + + Ok(block) + } + + fn statement( + &mut self, + stmt: &ast::Statement<'source>, + block: &mut crate::Block, + is_inside_loop: bool, + ctx: &mut StatementContext<'source, '_, '_>, + ) -> Result<(), Error<'source>> { + let out = match stmt.kind { + ast::StatementKind::Block(ref block) => { + let block = self.block(block, is_inside_loop, ctx)?; + crate::Statement::Block(block) + } + ast::StatementKind::LocalDecl(ref decl) => match *decl { + ast::LocalDecl::Let(ref l) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let value = + self.expression(l.init, &mut ctx.as_expression(block, &mut emitter))?; + + // The WGSL spec says that any expression that refers to a + // `let`-bound variable is not a const expression. This + // affects when errors must be reported, so we can't even + // treat suitable `let` bindings as constant as an + // optimization. + ctx.expression_constness.force_non_const(value); + + let explicit_ty = + l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global())) + .transpose()?; + + if let Some(ty) = explicit_ty { + let mut ctx = ctx.as_expression(block, &mut emitter); + let init_ty = ctx.register_type(value)?; + if !ctx.module.types[ty] + .inner + .equivalent(&ctx.module.types[init_ty].inner, &ctx.module.types) + { + let gctx = &ctx.module.to_ctx(); + return Err(Error::InitializationTypeMismatch { + name: l.name.span, + expected: ty.to_wgsl(gctx), + got: init_ty.to_wgsl(gctx), + }); + } + } + + block.extend(emitter.finish(&ctx.function.expressions)); + ctx.local_table.insert(l.handle, Typed::Plain(value)); + ctx.named_expressions + .insert(value, (l.name.name.to_string(), l.name.span)); + + return Ok(()); + } + ast::LocalDecl::Var(ref v) => { + let explicit_ty = + v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global())) + .transpose()?; + + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + let mut ectx = ctx.as_expression(block, &mut emitter); + + let ty; + let initializer; + match (v.init, explicit_ty) { + (Some(init), Some(explicit_ty)) => { + let init = self.expression_for_abstract(init, &mut ectx)?; + let ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + let init = ectx + .try_automatic_conversions(init, &ty_res, v.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: v.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + ty = explicit_ty; + initializer = Some(init); + } + (Some(init), None) => { + let concretized = self.expression(init, &mut ectx)?; + ty = ectx.register_type(concretized)?; + initializer = Some(concretized); + } + (None, Some(explicit_ty)) => { + ty = explicit_ty; + initializer = None; + } + (None, None) => return Err(Error::MissingType(v.name.span)), + } + + let (const_initializer, initializer) = { + match initializer { + Some(init) => { + // It's not correct to hoist the initializer up + // to the top of the function if: + // - the initialization is inside a loop, and should + // take place on every iteration, or + // - the initialization is not a constant + // expression, so its value depends on the + // state at the point of initialization. + if is_inside_loop || !ctx.expression_constness.is_const(init) { + (None, Some(init)) + } else { + (Some(init), None) + } + } + None => (None, None), + } + }; + + let var = ctx.function.local_variables.append( + crate::LocalVariable { + name: Some(v.name.name.to_string()), + ty, + init: const_initializer, + }, + stmt.span, + ); + + let handle = ctx.as_expression(block, &mut emitter).interrupt_emitter( + crate::Expression::LocalVariable(var), + Span::UNDEFINED, + )?; + block.extend(emitter.finish(&ctx.function.expressions)); + ctx.local_table.insert(v.handle, Typed::Reference(handle)); + + match initializer { + Some(initializer) => crate::Statement::Store { + pointer: handle, + value: initializer, + }, + None => return Ok(()), + } + } + }, + ast::StatementKind::If { + condition, + ref accept, + ref reject, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let condition = + self.expression(condition, &mut ctx.as_expression(block, &mut emitter))?; + block.extend(emitter.finish(&ctx.function.expressions)); + + let accept = self.block(accept, is_inside_loop, ctx)?; + let reject = self.block(reject, is_inside_loop, ctx)?; + + crate::Statement::If { + condition, + accept, + reject, + } + } + ast::StatementKind::Switch { + selector, + ref cases, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let mut ectx = ctx.as_expression(block, &mut emitter); + let selector = self.expression(selector, &mut ectx)?; + + let uint = + resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint); + block.extend(emitter.finish(&ctx.function.expressions)); + + let cases = cases + .iter() + .map(|case| { + Ok(crate::SwitchCase { + value: match case.value { + ast::SwitchValue::Expr(expr) => { + let span = ctx.ast_expressions.get_span(expr); + let expr = + self.expression(expr, &mut ctx.as_global().as_const())?; + match ctx.module.to_ctx().eval_expr_to_literal(expr) { + Some(crate::Literal::I32(value)) if !uint => { + crate::SwitchValue::I32(value) + } + Some(crate::Literal::U32(value)) if uint => { + crate::SwitchValue::U32(value) + } + _ => { + return Err(Error::InvalidSwitchValue { uint, span }); + } + } + } + ast::SwitchValue::Default => crate::SwitchValue::Default, + }, + body: self.block(&case.body, is_inside_loop, ctx)?, + fall_through: case.fall_through, + }) + }) + .collect::<Result<_, _>>()?; + + crate::Statement::Switch { selector, cases } + } + ast::StatementKind::Loop { + ref body, + ref continuing, + break_if, + } => { + let body = self.block(body, true, ctx)?; + let mut continuing = self.block(continuing, true, ctx)?; + + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + let break_if = break_if + .map(|expr| { + self.expression(expr, &mut ctx.as_expression(&mut continuing, &mut emitter)) + }) + .transpose()?; + continuing.extend(emitter.finish(&ctx.function.expressions)); + + crate::Statement::Loop { + body, + continuing, + break_if, + } + } + ast::StatementKind::Break => crate::Statement::Break, + ast::StatementKind::Continue => crate::Statement::Continue, + ast::StatementKind::Return { value } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let value = value + .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter))) + .transpose()?; + block.extend(emitter.finish(&ctx.function.expressions)); + + crate::Statement::Return { value } + } + ast::StatementKind::Kill => crate::Statement::Kill, + ast::StatementKind::Call { + ref function, + ref arguments, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let _ = self.call( + stmt.span, + function, + arguments, + &mut ctx.as_expression(block, &mut emitter), + )?; + block.extend(emitter.finish(&ctx.function.expressions)); + return Ok(()); + } + ast::StatementKind::Assign { + target: ast_target, + op, + value, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let target = self.expression_for_reference( + ast_target, + &mut ctx.as_expression(block, &mut emitter), + )?; + let mut value = + self.expression(value, &mut ctx.as_expression(block, &mut emitter))?; + + let target_handle = match target { + Typed::Reference(handle) => handle, + Typed::Plain(handle) => { + let ty = ctx.invalid_assignment_type(handle); + return Err(Error::InvalidAssignment { + span: ctx.ast_expressions.get_span(ast_target), + ty, + }); + } + }; + + let value = match op { + Some(op) => { + let mut ctx = ctx.as_expression(block, &mut emitter); + let mut left = ctx.apply_load_rule(target)?; + ctx.binary_op_splat(op, &mut left, &mut value)?; + ctx.append_expression( + crate::Expression::Binary { + op, + left, + right: value, + }, + stmt.span, + )? + } + None => value, + }; + block.extend(emitter.finish(&ctx.function.expressions)); + + crate::Statement::Store { + pointer: target_handle, + value, + } + } + ast::StatementKind::Increment(value) | ast::StatementKind::Decrement(value) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let op = match stmt.kind { + ast::StatementKind::Increment(_) => crate::BinaryOperator::Add, + ast::StatementKind::Decrement(_) => crate::BinaryOperator::Subtract, + _ => unreachable!(), + }; + + let value_span = ctx.ast_expressions.get_span(value); + let target = self + .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?; + let target_handle = match target { + Typed::Reference(handle) => handle, + Typed::Plain(_) => return Err(Error::BadIncrDecrReferenceType(value_span)), + }; + + let mut ectx = ctx.as_expression(block, &mut emitter); + let scalar = match *resolve_inner!(ectx, target_handle) { + crate::TypeInner::ValuePointer { + size: None, scalar, .. + } => scalar, + crate::TypeInner::Pointer { base, .. } => match ectx.module.types[base].inner { + crate::TypeInner::Scalar(scalar) => scalar, + _ => return Err(Error::BadIncrDecrReferenceType(value_span)), + }, + _ => return Err(Error::BadIncrDecrReferenceType(value_span)), + }; + let literal = match scalar.kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + crate::Literal::one(scalar) + .ok_or(Error::BadIncrDecrReferenceType(value_span))? + } + _ => return Err(Error::BadIncrDecrReferenceType(value_span)), + }; + + let right = + ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED)?; + let rctx = ectx.runtime_expression_ctx(stmt.span)?; + let left = rctx.function.expressions.append( + crate::Expression::Load { + pointer: target_handle, + }, + value_span, + ); + let value = rctx + .function + .expressions + .append(crate::Expression::Binary { op, left, right }, stmt.span); + + block.extend(emitter.finish(&ctx.function.expressions)); + crate::Statement::Store { + pointer: target_handle, + value, + } + } + ast::StatementKind::Ignore(expr) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let _ = self.expression(expr, &mut ctx.as_expression(block, &mut emitter))?; + block.extend(emitter.finish(&ctx.function.expressions)); + return Ok(()); + } + }; + + block.push(out, stmt.span); + + Ok(()) + } + + /// Lower `expr` and apply the Load Rule if possible. + /// + /// For the time being, this concretizes abstract values, to support + /// consumers that haven't been adapted to consume them yet. Consumers + /// prepared for abstract values can call [`expression_for_abstract`]. + /// + /// [`expression_for_abstract`]: Lowerer::expression_for_abstract + fn expression( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let expr = self.expression_for_abstract(expr, ctx)?; + ctx.concretize(expr) + } + + fn expression_for_abstract( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let expr = self.expression_for_reference(expr, ctx)?; + ctx.apply_load_rule(expr) + } + + fn expression_for_reference( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Typed<Handle<crate::Expression>>, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let expr = &ctx.ast_expressions[expr]; + + let expr: Typed<crate::Expression> = match *expr { + ast::Expression::Literal(literal) => { + let literal = match literal { + ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f), + ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), + ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), + ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f), + ast::Literal::Number(Number::AbstractInt(i)) => crate::Literal::AbstractInt(i), + ast::Literal::Number(Number::AbstractFloat(f)) => { + crate::Literal::AbstractFloat(f) + } + ast::Literal::Bool(b) => crate::Literal::Bool(b), + }; + let handle = ctx.interrupt_emitter(crate::Expression::Literal(literal), span)?; + return Ok(Typed::Plain(handle)); + } + ast::Expression::Ident(ast::IdentExpr::Local(local)) => { + let rctx = ctx.runtime_expression_ctx(span)?; + return Ok(rctx.local_table[&local]); + } + ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => { + let global = ctx + .globals + .get(name) + .ok_or(Error::UnknownIdent(span, name))?; + let expr = match *global { + LoweredGlobalDecl::Var(handle) => { + let expr = crate::Expression::GlobalVariable(handle); + match ctx.module.global_variables[handle].space { + crate::AddressSpace::Handle => Typed::Plain(expr), + _ => Typed::Reference(expr), + } + } + LoweredGlobalDecl::Const(handle) => { + Typed::Plain(crate::Expression::Constant(handle)) + } + _ => { + return Err(Error::Unexpected(span, ExpectedToken::Variable)); + } + }; + + return expr.try_map(|handle| ctx.interrupt_emitter(handle, span)); + } + ast::Expression::Construct { + ref ty, + ty_span, + ref components, + } => { + let handle = self.construct(span, ty, ty_span, components, ctx)?; + return Ok(Typed::Plain(handle)); + } + ast::Expression::Unary { op, expr } => { + let expr = self.expression_for_abstract(expr, ctx)?; + Typed::Plain(crate::Expression::Unary { op, expr }) + } + ast::Expression::AddrOf(expr) => { + // The `&` operator simply converts a reference to a pointer. And since a + // reference is required, the Load Rule is not applied. + match self.expression_for_reference(expr, ctx)? { + Typed::Reference(handle) => { + // No code is generated. We just declare the reference a pointer now. + return Ok(Typed::Plain(handle)); + } + Typed::Plain(_) => { + return Err(Error::NotReference("the operand of the `&` operator", span)); + } + } + } + ast::Expression::Deref(expr) => { + // The pointer we dereference must be loaded. + let pointer = self.expression(expr, ctx)?; + + if resolve_inner!(ctx, pointer).pointer_space().is_none() { + return Err(Error::NotPointer(span)); + } + + // No code is generated. We just declare the pointer a reference now. + return Ok(Typed::Reference(pointer)); + } + ast::Expression::Binary { op, left, right } => { + self.binary(op, left, right, span, ctx)? + } + ast::Expression::Call { + ref function, + ref arguments, + } => { + let handle = self + .call(span, function, arguments, ctx)? + .ok_or(Error::FunctionReturnsVoid(function.span))?; + return Ok(Typed::Plain(handle)); + } + ast::Expression::Index { base, index } => { + let lowered_base = self.expression_for_reference(base, ctx)?; + let index = self.expression(index, ctx)?; + + if let Typed::Plain(handle) = lowered_base { + if resolve_inner!(ctx, handle).pointer_space().is_some() { + return Err(Error::Pointer( + "the value indexed by a `[]` subscripting expression", + ctx.ast_expressions.get_span(base), + )); + } + } + + lowered_base.map(|base| match ctx.const_access(index) { + Some(index) => crate::Expression::AccessIndex { base, index }, + None => crate::Expression::Access { base, index }, + }) + } + ast::Expression::Member { base, ref field } => { + let lowered_base = self.expression_for_reference(base, ctx)?; + + let temp_inner; + let composite_type: &crate::TypeInner = match lowered_base { + Typed::Reference(handle) => { + let inner = resolve_inner!(ctx, handle); + match *inner { + crate::TypeInner::Pointer { base, .. } => &ctx.module.types[base].inner, + crate::TypeInner::ValuePointer { + size: None, scalar, .. + } => { + temp_inner = crate::TypeInner::Scalar(scalar); + &temp_inner + } + crate::TypeInner::ValuePointer { + size: Some(size), + scalar, + .. + } => { + temp_inner = crate::TypeInner::Vector { size, scalar }; + &temp_inner + } + _ => unreachable!( + "In Typed::Reference(handle), handle must be a Naga pointer" + ), + } + } + + Typed::Plain(handle) => { + let inner = resolve_inner!(ctx, handle); + if let crate::TypeInner::Pointer { .. } + | crate::TypeInner::ValuePointer { .. } = *inner + { + return Err(Error::Pointer( + "the value accessed by a `.member` expression", + ctx.ast_expressions.get_span(base), + )); + } + inner + } + }; + + let access = match *composite_type { + crate::TypeInner::Struct { ref members, .. } => { + let index = members + .iter() + .position(|m| m.name.as_deref() == Some(field.name)) + .ok_or(Error::BadAccessor(field.span))? + as u32; + + lowered_base.map(|base| crate::Expression::AccessIndex { base, index }) + } + crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => { + match Components::new(field.name, field.span)? { + Components::Swizzle { size, pattern } => { + // Swizzles aren't allowed on matrices, but + // validation will catch that. + Typed::Plain(crate::Expression::Swizzle { + size, + vector: ctx.apply_load_rule(lowered_base)?, + pattern, + }) + } + Components::Single(index) => lowered_base + .map(|base| crate::Expression::AccessIndex { base, index }), + } + } + _ => return Err(Error::BadAccessor(field.span)), + }; + + access + } + ast::Expression::Bitcast { expr, to, ty_span } => { + let expr = self.expression(expr, ctx)?; + let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?; + + let element_scalar = match ctx.module.types[to_resolved].inner { + crate::TypeInner::Scalar(scalar) => scalar, + crate::TypeInner::Vector { scalar, .. } => scalar, + _ => { + let ty = resolve!(ctx, expr); + let gctx = &ctx.module.to_ctx(); + return Err(Error::BadTypeCast { + from_type: ty.to_wgsl(gctx), + span: ty_span, + to_type: to_resolved.to_wgsl(gctx), + }); + } + }; + + Typed::Plain(crate::Expression::As { + expr, + kind: element_scalar.kind, + convert: None, + }) + } + }; + + expr.try_map(|handle| ctx.append_expression(handle, span)) + } + + fn binary( + &mut self, + op: crate::BinaryOperator, + left: Handle<ast::Expression<'source>>, + right: Handle<ast::Expression<'source>>, + span: Span, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Typed<crate::Expression>, Error<'source>> { + // Load both operands. + let mut left = self.expression_for_abstract(left, ctx)?; + let mut right = self.expression_for_abstract(right, ctx)?; + + // Convert `scalar op vector` to `vector op vector` by introducing + // `Splat` expressions. + ctx.binary_op_splat(op, &mut left, &mut right)?; + + // Apply automatic conversions. + match op { + // Shift operators require the right operand to be `u32` or + // `vecN<u32>`. We can let the validator sort out vector length + // issues, but the right operand must be, or convert to, a u32 leaf + // scalar. + crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight => { + right = + ctx.try_automatic_conversion_for_leaf_scalar(right, crate::Scalar::U32, span)?; + } + + // All other operators follow the same pattern: reconcile the + // scalar leaf types. If there's no reconciliation possible, + // leave the expressions as they are: validation will report the + // problem. + _ => { + ctx.grow_types(left)?; + ctx.grow_types(right)?; + if let Ok(consensus_scalar) = + ctx.automatic_conversion_consensus([left, right].iter()) + { + ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?; + ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?; + } + } + } + + Ok(Typed::Plain(crate::Expression::Binary { op, left, right })) + } + + /// Generate Naga IR for call expressions and statements, and type + /// constructor expressions. + /// + /// The "function" being called is simply an `Ident` that we know refers to + /// some module-scope definition. + /// + /// - If it is the name of a type, then the expression is a type constructor + /// expression: either constructing a value from components, a conversion + /// expression, or a zero value expression. + /// + /// - If it is the name of a function, then we're generating a [`Call`] + /// statement. We may be in the midst of generating code for an + /// expression, in which case we must generate an `Emit` statement to + /// force evaluation of the IR expressions we've generated so far, add the + /// `Call` statement to the current block, and then resume generating + /// expressions. + /// + /// [`Call`]: crate::Statement::Call + fn call( + &mut self, + span: Span, + function: &ast::Ident<'source>, + arguments: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Option<Handle<crate::Expression>>, Error<'source>> { + match ctx.globals.get(function.name) { + Some(&LoweredGlobalDecl::Type(ty)) => { + let handle = self.construct( + span, + &ast::ConstructorType::Type(ty), + function.span, + arguments, + ctx, + )?; + Ok(Some(handle)) + } + Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => { + Err(Error::Unexpected(function.span, ExpectedToken::Function)) + } + Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)), + Some(&LoweredGlobalDecl::Function(function)) => { + let arguments = arguments + .iter() + .map(|&arg| self.expression(arg, ctx)) + .collect::<Result<Vec<_>, _>>()?; + + let has_result = ctx.module.functions[function].result.is_some(); + let rctx = ctx.runtime_expression_ctx(span)?; + // we need to always do this before a fn call since all arguments need to be emitted before the fn call + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + let result = has_result.then(|| { + rctx.function + .expressions + .append(crate::Expression::CallResult(function), span) + }); + rctx.emitter.start(&rctx.function.expressions); + rctx.block.push( + crate::Statement::Call { + function, + arguments, + result, + }, + span, + ); + + Ok(result) + } + None => { + let span = function.span; + let expr = if let Some(fun) = conv::map_relational_fun(function.name) { + let mut args = ctx.prepare_args(arguments, 1, span); + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + // Check for no-op all(bool) and any(bool): + let argument_unmodified = matches!( + fun, + crate::RelationalFunction::All | crate::RelationalFunction::Any + ) && { + matches!( + resolve_inner!(ctx, argument), + &crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + .. + }) + ) + }; + + if argument_unmodified { + return Ok(Some(argument)); + } else { + crate::Expression::Relational { fun, argument } + } + } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::Derivative { axis, ctrl, expr } + } else if let Some(fun) = conv::map_standard_fun(function.name) { + let expected = fun.argument_count() as _; + let mut args = ctx.prepare_args(arguments, expected, span); + + let arg = self.expression(args.next()?, ctx)?; + let arg1 = args + .next() + .map(|x| self.expression(x, ctx)) + .ok() + .transpose()?; + let arg2 = args + .next() + .map(|x| self.expression(x, ctx)) + .ok() + .transpose()?; + let arg3 = args + .next() + .map(|x| self.expression(x, ctx)) + .ok() + .transpose()?; + + args.finish()?; + + if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp { + if let Some((size, width)) = match *resolve_inner!(ctx, arg) { + crate::TypeInner::Scalar(crate::Scalar { width, .. }) => { + Some((None, width)) + } + crate::TypeInner::Vector { + size, + scalar: crate::Scalar { width, .. }, + .. + } => Some((Some(size), width)), + _ => None, + } { + ctx.module.generate_predeclared_type( + if fun == crate::MathFunction::Modf { + crate::PredeclaredType::ModfResult { size, width } + } else { + crate::PredeclaredType::FrexpResult { size, width } + }, + ); + } + } + + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } + } else if let Some(fun) = Texture::map(function.name) { + self.texture_sample_helper(fun, arguments, span, ctx)? + } else { + match function.name { + "select" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let reject = self.expression(args.next()?, ctx)?; + let accept = self.expression(args.next()?, ctx)?; + let condition = self.expression(args.next()?, ctx)?; + + args.finish()?; + + crate::Expression::Select { + reject, + accept, + condition, + } + } + "arrayLength" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ArrayLength(expr) + } + "atomicLoad" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let pointer = self.atomic_pointer(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::Load { pointer } + } + "atomicStore" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let pointer = self.atomic_pointer(args.next()?, ctx)?; + let value = self.expression(args.next()?, ctx)?; + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + rctx.block + .push(crate::Statement::Store { pointer, value }, span); + return Ok(None); + } + "atomicAdd" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Add, + arguments, + ctx, + )?)) + } + "atomicSub" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Subtract, + arguments, + ctx, + )?)) + } + "atomicAnd" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::And, + arguments, + ctx, + )?)) + } + "atomicOr" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::InclusiveOr, + arguments, + ctx, + )?)) + } + "atomicXor" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::ExclusiveOr, + arguments, + ctx, + )?)) + } + "atomicMin" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Min, + arguments, + ctx, + )?)) + } + "atomicMax" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Max, + arguments, + ctx, + )?)) + } + "atomicExchange" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Exchange { compare: None }, + arguments, + ctx, + )?)) + } + "atomicCompareExchangeWeak" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let pointer = self.atomic_pointer(args.next()?, ctx)?; + + let compare = self.expression(args.next()?, ctx)?; + + let value = args.next()?; + let value_span = ctx.ast_expressions.get_span(value); + let value = self.expression(value, ctx)?; + + args.finish()?; + + let expression = match *resolve_inner!(ctx, value) { + crate::TypeInner::Scalar(scalar) => { + crate::Expression::AtomicResult { + ty: ctx.module.generate_predeclared_type( + crate::PredeclaredType::AtomicCompareExchangeWeakResult( + scalar, + ), + ), + comparison: true, + } + } + _ => return Err(Error::InvalidAtomicOperandType(value_span)), + }; + + let result = ctx.interrupt_emitter(expression, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::Atomic { + pointer, + fun: crate::AtomicFunction::Exchange { + compare: Some(compare), + }, + value, + result, + }, + span, + ); + return Ok(Some(result)); + } + "storageBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::STORAGE), span); + return Ok(None); + } + "workgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); + return Ok(None); + } + "workgroupUniformLoad" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = args.next()?; + args.finish()?; + + let pointer = self.expression(expr, ctx)?; + let result_ty = match *resolve_inner!(ctx, pointer) { + crate::TypeInner::Pointer { + base, + space: crate::AddressSpace::WorkGroup, + } => base, + ref other => { + log::error!("Type {other:?} passed to workgroupUniformLoad"); + let span = ctx.ast_expressions.get_span(expr); + return Err(Error::InvalidWorkGroupUniformLoad(span)); + } + }; + let result = ctx.interrupt_emitter( + crate::Expression::WorkGroupUniformLoadResult { ty: result_ty }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::WorkGroupUniformLoad { pointer, result }, + span, + ); + + return Ok(Some(result)); + } + "textureStore" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = self.expression(image, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (_, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let value = self.expression(args.next()?, ctx)?; + + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + let stmt = crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + }; + rctx.block.push(stmt, span); + return Ok(None); + } + "textureLoad" => { + let mut args = ctx.prepare_args(arguments, 2, span); + + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = self.expression(image, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (class, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let level = class + .is_mipmapped() + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let sample = class + .is_multisampled() + .then(|| self.expression(args.next()?, ctx)) + .transpose()?; + + args.finish()?; + + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + level, + sample, + } + } + "textureDimensions" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + let level = args + .next() + .map(|arg| self.expression(arg, ctx)) + .ok() + .transpose()?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::Size { level }, + } + } + "textureNumLevels" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumLevels, + } + } + "textureNumLayers" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumLayers, + } + } + "textureNumSamples" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumSamples, + } + } + "rayQueryInitialize" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + let acceleration_structure = self.expression(args.next()?, ctx)?; + let descriptor = self.expression(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_ray_desc_type(); + let fun = crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + }; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + rctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(None); + } + "rayQueryProceed" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let result = ctx.interrupt_emitter( + crate::Expression::RayQueryProceedResult, + span, + )?; + let fun = crate::RayQueryFunction::Proceed { result }; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(Some(result)); + } + "rayQueryGetCommittedIntersection" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_ray_intersection_type(); + + crate::Expression::RayQueryGetIntersection { + query, + committed: true, + } + } + "RayDesc" => { + let ty = ctx.module.generate_ray_desc_type(); + let handle = self.construct( + span, + &ast::ConstructorType::Type(ty), + function.span, + arguments, + ctx, + )?; + return Ok(Some(handle)); + } + _ => return Err(Error::UnknownIdent(function.span, function.name)), + } + }; + + let expr = ctx.append_expression(expr, span)?; + Ok(Some(expr)) + } + } + } + + fn atomic_pointer( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let pointer = self.expression(expr, ctx)?; + + match *resolve_inner!(ctx, pointer) { + crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { + crate::TypeInner::Atomic { .. } => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to atomic op", other); + Err(Error::InvalidAtomicPointer(span)) + } + }, + ref other => { + log::error!("Type {:?} passed to atomic op", other); + Err(Error::InvalidAtomicPointer(span)) + } + } + } + + fn atomic_helper( + &mut self, + span: Span, + fun: crate::AtomicFunction, + args: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let mut args = ctx.prepare_args(args, 2, span); + + let pointer = self.atomic_pointer(args.next()?, ctx)?; + + let value = args.next()?; + let value = self.expression(value, ctx)?; + let ty = ctx.register_type(value)?; + + args.finish()?; + + let result = ctx.interrupt_emitter( + crate::Expression::AtomicResult { + ty, + comparison: false, + }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::Atomic { + pointer, + fun, + value, + result, + }, + span, + ); + Ok(result) + } + + fn texture_sample_helper( + &mut self, + fun: Texture, + args: &[Handle<ast::Expression<'source>>], + span: Span, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<crate::Expression, Error<'source>> { + let mut args = ctx.prepare_args(args, fun.min_argument_count(), span); + + fn get_image_and_span<'source>( + lowerer: &mut Lowerer<'source, '_>, + args: &mut ArgumentContext<'_, 'source>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<(Handle<crate::Expression>, Span), Error<'source>> { + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = lowerer.expression(image, ctx)?; + Ok((image, image_span)) + } + + let (image, image_span, gather) = match fun { + Texture::Gather => { + let image_or_component = args.next()?; + let image_or_component_span = ctx.ast_expressions.get_span(image_or_component); + // Gathers from depth textures don't take an initial `component` argument. + let lowered_image_or_component = self.expression(image_or_component, ctx)?; + + match *resolve_inner!(ctx, lowered_image_or_component) { + crate::TypeInner::Image { + class: crate::ImageClass::Depth { .. }, + .. + } => ( + lowered_image_or_component, + image_or_component_span, + Some(crate::SwizzleComponent::X), + ), + _ => { + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; + ( + image, + image_span, + Some(ctx.gather_component( + lowered_image_or_component, + image_or_component_span, + span, + )?), + ) + } + } + } + Texture::GatherCompare => { + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; + (image, image_span, Some(crate::SwizzleComponent::X)) + } + + _ => { + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; + (image, image_span, None) + } + }; + + let sampler = self.expression(args.next()?, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (_, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| self.expression(args.next()?, ctx)) + .transpose()?; + + let (level, depth_ref) = match fun { + Texture::Gather => (crate::SampleLevel::Zero, None), + Texture::GatherCompare => { + let reference = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Zero, Some(reference)) + } + + Texture::Sample => (crate::SampleLevel::Auto, None), + Texture::SampleBias => { + let bias = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Bias(bias), None) + } + Texture::SampleCompare => { + let reference = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Auto, Some(reference)) + } + Texture::SampleCompareLevel => { + let reference = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Zero, Some(reference)) + } + Texture::SampleGrad => { + let x = self.expression(args.next()?, ctx)?; + let y = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Gradient { x, y }, None) + } + Texture::SampleLevel => { + let level = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Exact(level), None) + } + }; + + let offset = args + .next() + .map(|arg| self.expression(arg, &mut ctx.as_const())) + .ok() + .transpose()?; + + args.finish()?; + + Ok(crate::Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + }) + } + + fn r#struct( + &mut self, + s: &ast::Struct<'source>, + span: Span, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<Handle<crate::Type>, Error<'source>> { + let mut offset = 0; + let mut struct_alignment = Alignment::ONE; + let mut members = Vec::with_capacity(s.members.len()); + + for member in s.members.iter() { + let ty = self.resolve_ast_type(member.ty, ctx)?; + + self.layouter.update(ctx.module.to_ctx()).unwrap(); + + let member_min_size = self.layouter[ty].size; + let member_min_alignment = self.layouter[ty].alignment; + + let member_size = if let Some(size_expr) = member.size { + let (size, span) = self.const_u32(size_expr, &mut ctx.as_const())?; + if size < member_min_size { + return Err(Error::SizeAttributeTooLow(span, member_min_size)); + } else { + size + } + } else { + member_min_size + }; + + let member_alignment = if let Some(align_expr) = member.align { + let (align, span) = self.const_u32(align_expr, &mut ctx.as_const())?; + if let Some(alignment) = Alignment::new(align) { + if alignment < member_min_alignment { + return Err(Error::AlignAttributeTooLow(span, member_min_alignment)); + } else { + alignment + } + } else { + return Err(Error::NonPowerOfTwoAlignAttribute(span)); + } + } else { + member_min_alignment + }; + + let binding = self.binding(&member.binding, ty, ctx)?; + + offset = member_alignment.round_up(offset); + struct_alignment = struct_alignment.max(member_alignment); + + members.push(crate::StructMember { + name: Some(member.name.name.to_owned()), + ty, + binding, + offset, + }); + + offset += member_size; + } + + let size = struct_alignment.round_up(offset); + let inner = crate::TypeInner::Struct { + members, + span: size, + }; + + let handle = ctx.module.types.insert( + crate::Type { + name: Some(s.name.name.to_string()), + inner, + }, + span, + ); + Ok(handle) + } + + fn const_u32( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<(u32, Span), Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let expr = self.expression(expr, ctx)?; + let value = ctx + .module + .to_ctx() + .eval_expr_to_u32(expr) + .map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } + crate::proc::U32EvalError::Negative => Error::ExpectedNonNegative(span), + })?; + Ok((value, span)) + } + + fn array_size( + &mut self, + size: ast::ArraySize<'source>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<crate::ArraySize, Error<'source>> { + Ok(match size { + ast::ArraySize::Constant(expr) => { + let span = ctx.ast_expressions.get_span(expr); + let const_expr = self.expression(expr, &mut ctx.as_const())?; + let len = + ctx.module + .to_ctx() + .eval_expr_to_u32(const_expr) + .map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } + crate::proc::U32EvalError::Negative => { + Error::ExpectedPositiveArrayLength(span) + } + })?; + let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?; + crate::ArraySize::Constant(size) + } + ast::ArraySize::Dynamic => crate::ArraySize::Dynamic, + }) + } + + /// Build the Naga equivalent of a named AST type. + /// + /// Return a Naga `Handle<Type>` representing the front-end type + /// `handle`, which should be named `name`, if given. + /// + /// If `handle` refers to a type cached in [`SpecialTypes`], + /// `name` may be ignored. + /// + /// [`SpecialTypes`]: crate::SpecialTypes + fn resolve_named_ast_type( + &mut self, + handle: Handle<ast::Type<'source>>, + name: Option<String>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<Handle<crate::Type>, Error<'source>> { + let inner = match ctx.types[handle] { + ast::Type::Scalar(scalar) => scalar.to_inner_scalar(), + ast::Type::Vector { size, scalar } => scalar.to_inner_vector(size), + ast::Type::Matrix { + rows, + columns, + width, + } => crate::TypeInner::Matrix { + columns, + rows, + scalar: crate::Scalar::float(width), + }, + ast::Type::Atomic(scalar) => scalar.to_inner_atomic(), + ast::Type::Pointer { base, space } => { + let base = self.resolve_ast_type(base, ctx)?; + crate::TypeInner::Pointer { base, space } + } + ast::Type::Array { base, size } => { + let base = self.resolve_ast_type(base, ctx)?; + let size = self.array_size(size, ctx)?; + + self.layouter.update(ctx.module.to_ctx()).unwrap(); + let stride = self.layouter[base].to_stride(); + + crate::TypeInner::Array { base, size, stride } + } + ast::Type::Image { + dim, + arrayed, + class, + } => crate::TypeInner::Image { + dim, + arrayed, + class, + }, + ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison }, + ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure, + ast::Type::RayQuery => crate::TypeInner::RayQuery, + ast::Type::BindingArray { base, size } => { + let base = self.resolve_ast_type(base, ctx)?; + let size = self.array_size(size, ctx)?; + crate::TypeInner::BindingArray { base, size } + } + ast::Type::RayDesc => { + return Ok(ctx.module.generate_ray_desc_type()); + } + ast::Type::RayIntersection => { + return Ok(ctx.module.generate_ray_intersection_type()); + } + ast::Type::User(ref ident) => { + return match ctx.globals.get(ident.name) { + Some(&LoweredGlobalDecl::Type(handle)) => Ok(handle), + Some(_) => Err(Error::Unexpected(ident.span, ExpectedToken::Type)), + None => Err(Error::UnknownType(ident.span)), + } + } + }; + + Ok(ctx.ensure_type_exists(name, inner)) + } + + /// Return a Naga `Handle<Type>` representing the front-end type `handle`. + fn resolve_ast_type( + &mut self, + handle: Handle<ast::Type<'source>>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<Handle<crate::Type>, Error<'source>> { + self.resolve_named_ast_type(handle, None, ctx) + } + + fn binding( + &mut self, + binding: &Option<ast::Binding<'source>>, + ty: Handle<crate::Type>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<Option<crate::Binding>, Error<'source>> { + Ok(match *binding { + Some(ast::Binding::BuiltIn(b)) => Some(crate::Binding::BuiltIn(b)), + Some(ast::Binding::Location { + location, + second_blend_source, + interpolation, + sampling, + }) => { + let mut binding = crate::Binding::Location { + location: self.const_u32(location, &mut ctx.as_const())?.0, + second_blend_source, + interpolation, + sampling, + }; + binding.apply_default_interpolation(&ctx.module.types[ty].inner); + Some(binding) + } + None => None, + }) + } + + fn ray_query_pointer( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let pointer = self.expression(expr, ctx)?; + + match *resolve_inner!(ctx, pointer) { + crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { + crate::TypeInner::RayQuery => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to ray query op", other); + Err(Error::InvalidRayQueryPointer(span)) + } + }, + ref other => { + log::error!("Type {:?} passed to ray query op", other); + Err(Error::InvalidRayQueryPointer(span)) + } + } + } +} diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs new file mode 100644 index 0000000000..b6151fe1c0 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/mod.rs @@ -0,0 +1,49 @@ +/*! +Frontend for [WGSL][wgsl] (WebGPU Shading Language). + +[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html +*/ + +mod error; +mod index; +mod lower; +mod parse; +#[cfg(test)] +mod tests; +mod to_wgsl; + +use crate::front::wgsl::error::Error; +use crate::front::wgsl::parse::Parser; +use thiserror::Error; + +pub use crate::front::wgsl::error::ParseError; +use crate::front::wgsl::lower::Lowerer; +use crate::Scalar; + +pub struct Frontend { + parser: Parser, +} + +impl Frontend { + pub const fn new() -> Self { + Self { + parser: Parser::new(), + } + } + + pub fn parse(&mut self, source: &str) -> Result<crate::Module, ParseError> { + self.inner(source).map_err(|x| x.as_parse_error(source)) + } + + fn inner<'a>(&mut self, source: &'a str) -> Result<crate::Module, Error<'a>> { + let tu = self.parser.parse(source)?; + let index = index::Index::generate(&tu)?; + let module = Lowerer::new(&index).lower(&tu)?; + + Ok(module) + } +} + +pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> { + Frontend::new().parse(source) +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs new file mode 100644 index 0000000000..dbaac523cb --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs @@ -0,0 +1,491 @@ +use crate::front::wgsl::parse::number::Number; +use crate::front::wgsl::Scalar; +use crate::{Arena, FastIndexSet, Handle, Span}; +use std::hash::Hash; + +#[derive(Debug, Default)] +pub struct TranslationUnit<'a> { + pub decls: Arena<GlobalDecl<'a>>, + /// The common expressions arena for the entire translation unit. + /// + /// All functions, global initializers, array lengths, etc. store their + /// expressions here. We apportion these out to individual Naga + /// [`Function`]s' expression arenas at lowering time. Keeping them all in a + /// single arena simplifies handling of things like array lengths (which are + /// effectively global and thus don't clearly belong to any function) and + /// initializers (which can appear in both function-local and module-scope + /// contexts). + /// + /// [`Function`]: crate::Function + pub expressions: Arena<Expression<'a>>, + + /// Non-user-defined types, like `vec4<f32>` or `array<i32, 10>`. + /// + /// These are referred to by `Handle<ast::Type<'a>>` values. + /// User-defined types are referred to by name until lowering. + pub types: Arena<Type<'a>>, +} + +#[derive(Debug, Clone, Copy)] +pub struct Ident<'a> { + pub name: &'a str, + pub span: Span, +} + +#[derive(Debug)] +pub enum IdentExpr<'a> { + Unresolved(&'a str), + Local(Handle<Local>), +} + +/// A reference to a module-scope definition or predeclared object. +/// +/// Each [`GlobalDecl`] holds a set of these values, to be resolved to +/// specific definitions later. To support de-duplication, `Eq` and +/// `Hash` on a `Dependency` value consider only the name, not the +/// source location at which the reference occurs. +#[derive(Debug)] +pub struct Dependency<'a> { + /// The name referred to. + pub ident: &'a str, + + /// The location at which the reference to that name occurs. + pub usage: Span, +} + +impl Hash for Dependency<'_> { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + self.ident.hash(state); + } +} + +impl PartialEq for Dependency<'_> { + fn eq(&self, other: &Self) -> bool { + self.ident == other.ident + } +} + +impl Eq for Dependency<'_> {} + +/// A module-scope declaration. +#[derive(Debug)] +pub struct GlobalDecl<'a> { + pub kind: GlobalDeclKind<'a>, + + /// Names of all module-scope or predeclared objects this + /// declaration uses. + pub dependencies: FastIndexSet<Dependency<'a>>, +} + +#[derive(Debug)] +pub enum GlobalDeclKind<'a> { + Fn(Function<'a>), + Var(GlobalVariable<'a>), + Const(Const<'a>), + Struct(Struct<'a>), + Type(TypeAlias<'a>), +} + +#[derive(Debug)] +pub struct FunctionArgument<'a> { + pub name: Ident<'a>, + pub ty: Handle<Type<'a>>, + pub binding: Option<Binding<'a>>, + pub handle: Handle<Local>, +} + +#[derive(Debug)] +pub struct FunctionResult<'a> { + pub ty: Handle<Type<'a>>, + pub binding: Option<Binding<'a>>, +} + +#[derive(Debug)] +pub struct EntryPoint<'a> { + pub stage: crate::ShaderStage, + pub early_depth_test: Option<crate::EarlyDepthTest>, + pub workgroup_size: Option<[Option<Handle<Expression<'a>>>; 3]>, +} + +#[cfg(doc)] +use crate::front::wgsl::lower::{RuntimeExpressionContext, StatementContext}; + +#[derive(Debug)] +pub struct Function<'a> { + pub entry_point: Option<EntryPoint<'a>>, + pub name: Ident<'a>, + pub arguments: Vec<FunctionArgument<'a>>, + pub result: Option<FunctionResult<'a>>, + + /// Local variable and function argument arena. + /// + /// Note that the `Local` here is actually a zero-sized type. The AST keeps + /// all the detailed information about locals - names, types, etc. - in + /// [`LocalDecl`] statements. For arguments, that information is kept in + /// [`arguments`]. This `Arena`'s only role is to assign a unique `Handle` + /// to each of them, and track their definitions' spans for use in + /// diagnostics. + /// + /// In the AST, when an [`Ident`] expression refers to a local variable or + /// argument, its [`IdentExpr`] holds the referent's `Handle<Local>` in this + /// arena. + /// + /// During lowering, [`LocalDecl`] statements add entries to a per-function + /// table that maps `Handle<Local>` values to their Naga representations, + /// accessed via [`StatementContext::local_table`] and + /// [`RuntimeExpressionContext::local_table`]. This table is then consulted when + /// lowering subsequent [`Ident`] expressions. + /// + /// [`LocalDecl`]: StatementKind::LocalDecl + /// [`arguments`]: Function::arguments + /// [`Ident`]: Expression::Ident + /// [`StatementContext::local_table`]: StatementContext::local_table + /// [`RuntimeExpressionContext::local_table`]: RuntimeExpressionContext::local_table + pub locals: Arena<Local>, + + pub body: Block<'a>, +} + +#[derive(Debug)] +pub enum Binding<'a> { + BuiltIn(crate::BuiltIn), + Location { + location: Handle<Expression<'a>>, + second_blend_source: bool, + interpolation: Option<crate::Interpolation>, + sampling: Option<crate::Sampling>, + }, +} + +#[derive(Debug)] +pub struct ResourceBinding<'a> { + pub group: Handle<Expression<'a>>, + pub binding: Handle<Expression<'a>>, +} + +#[derive(Debug)] +pub struct GlobalVariable<'a> { + pub name: Ident<'a>, + pub space: crate::AddressSpace, + pub binding: Option<ResourceBinding<'a>>, + pub ty: Handle<Type<'a>>, + pub init: Option<Handle<Expression<'a>>>, +} + +#[derive(Debug)] +pub struct StructMember<'a> { + pub name: Ident<'a>, + pub ty: Handle<Type<'a>>, + pub binding: Option<Binding<'a>>, + pub align: Option<Handle<Expression<'a>>>, + pub size: Option<Handle<Expression<'a>>>, +} + +#[derive(Debug)] +pub struct Struct<'a> { + pub name: Ident<'a>, + pub members: Vec<StructMember<'a>>, +} + +#[derive(Debug)] +pub struct TypeAlias<'a> { + pub name: Ident<'a>, + pub ty: Handle<Type<'a>>, +} + +#[derive(Debug)] +pub struct Const<'a> { + pub name: Ident<'a>, + pub ty: Option<Handle<Type<'a>>>, + pub init: Handle<Expression<'a>>, +} + +/// The size of an [`Array`] or [`BindingArray`]. +/// +/// [`Array`]: Type::Array +/// [`BindingArray`]: Type::BindingArray +#[derive(Debug, Copy, Clone)] +pub enum ArraySize<'a> { + /// The length as a constant expression. + Constant(Handle<Expression<'a>>), + Dynamic, +} + +#[derive(Debug)] +pub enum Type<'a> { + Scalar(Scalar), + Vector { + size: crate::VectorSize, + scalar: Scalar, + }, + Matrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + width: crate::Bytes, + }, + Atomic(Scalar), + Pointer { + base: Handle<Type<'a>>, + space: crate::AddressSpace, + }, + Array { + base: Handle<Type<'a>>, + size: ArraySize<'a>, + }, + Image { + dim: crate::ImageDimension, + arrayed: bool, + class: crate::ImageClass, + }, + Sampler { + comparison: bool, + }, + AccelerationStructure, + RayQuery, + RayDesc, + RayIntersection, + BindingArray { + base: Handle<Type<'a>>, + size: ArraySize<'a>, + }, + + /// A user-defined type, like a struct or a type alias. + User(Ident<'a>), +} + +#[derive(Debug, Default)] +pub struct Block<'a> { + pub stmts: Vec<Statement<'a>>, +} + +#[derive(Debug)] +pub struct Statement<'a> { + pub kind: StatementKind<'a>, + pub span: Span, +} + +#[derive(Debug)] +pub enum StatementKind<'a> { + LocalDecl(LocalDecl<'a>), + Block(Block<'a>), + If { + condition: Handle<Expression<'a>>, + accept: Block<'a>, + reject: Block<'a>, + }, + Switch { + selector: Handle<Expression<'a>>, + cases: Vec<SwitchCase<'a>>, + }, + Loop { + body: Block<'a>, + continuing: Block<'a>, + break_if: Option<Handle<Expression<'a>>>, + }, + Break, + Continue, + Return { + value: Option<Handle<Expression<'a>>>, + }, + Kill, + Call { + function: Ident<'a>, + arguments: Vec<Handle<Expression<'a>>>, + }, + Assign { + target: Handle<Expression<'a>>, + op: Option<crate::BinaryOperator>, + value: Handle<Expression<'a>>, + }, + Increment(Handle<Expression<'a>>), + Decrement(Handle<Expression<'a>>), + Ignore(Handle<Expression<'a>>), +} + +#[derive(Debug)] +pub enum SwitchValue<'a> { + Expr(Handle<Expression<'a>>), + Default, +} + +#[derive(Debug)] +pub struct SwitchCase<'a> { + pub value: SwitchValue<'a>, + pub body: Block<'a>, + pub fall_through: bool, +} + +/// A type at the head of a [`Construct`] expression. +/// +/// WGSL has two types of [`type constructor expressions`]: +/// +/// - Those that fully specify the type being constructed, like +/// `vec3<f32>(x,y,z)`, which obviously constructs a `vec3<f32>`. +/// +/// - Those that leave the component type of the composite being constructed +/// implicit, to be inferred from the argument types, like `vec3(x,y,z)`, +/// which constructs a `vec3<T>` where `T` is the type of `x`, `y`, and `z`. +/// +/// This enum represents the head type of both cases. The `PartialFoo` variants +/// represent the second case, where the component type is implicit. +/// +/// This does not cover structs or types referred to by type aliases. See the +/// documentation for [`Construct`] and [`Call`] expressions for details. +/// +/// [`Construct`]: Expression::Construct +/// [`type constructor expressions`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr +/// [`Call`]: Expression::Call +#[derive(Debug)] +pub enum ConstructorType<'a> { + /// A scalar type or conversion: `f32(1)`. + Scalar(Scalar), + + /// A vector construction whose component type is inferred from the + /// argument: `vec3(1.0)`. + PartialVector { size: crate::VectorSize }, + + /// A vector construction whose component type is written out: + /// `vec3<f32>(1.0)`. + Vector { + size: crate::VectorSize, + scalar: Scalar, + }, + + /// A matrix construction whose component type is inferred from the + /// argument: `mat2x2(1,2,3,4)`. + PartialMatrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + }, + + /// A matrix construction whose component type is written out: + /// `mat2x2<f32>(1,2,3,4)`. + Matrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + width: crate::Bytes, + }, + + /// An array whose component type and size are inferred from the arguments: + /// `array(3,4,5)`. + PartialArray, + + /// An array whose component type and size are written out: + /// `array<u32, 4>(3,4,5)`. + Array { + base: Handle<Type<'a>>, + size: ArraySize<'a>, + }, + + /// Constructing a value of a known Naga IR type. + /// + /// This variant is produced only during lowering, when we have Naga types + /// available, never during parsing. + Type(Handle<crate::Type>), +} + +#[derive(Debug, Copy, Clone)] +pub enum Literal { + Bool(bool), + Number(Number), +} + +#[cfg(doc)] +use crate::front::wgsl::lower::Lowerer; + +#[derive(Debug)] +pub enum Expression<'a> { + Literal(Literal), + Ident(IdentExpr<'a>), + + /// A type constructor expression. + /// + /// This is only used for expressions like `KEYWORD(EXPR...)` and + /// `KEYWORD<PARAM>(EXPR...)`, where `KEYWORD` is a [type-defining keyword] like + /// `vec3`. These keywords cannot be shadowed by user definitions, so we can + /// tell that such an expression is a construction immediately. + /// + /// For ordinary identifiers, we can't tell whether an expression like + /// `IDENTIFIER(EXPR, ...)` is a construction expression or a function call + /// until we know `IDENTIFIER`'s definition, so we represent those as + /// [`Call`] expressions. + /// + /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords + /// [`Call`]: Expression::Call + Construct { + ty: ConstructorType<'a>, + ty_span: Span, + components: Vec<Handle<Expression<'a>>>, + }, + Unary { + op: crate::UnaryOperator, + expr: Handle<Expression<'a>>, + }, + AddrOf(Handle<Expression<'a>>), + Deref(Handle<Expression<'a>>), + Binary { + op: crate::BinaryOperator, + left: Handle<Expression<'a>>, + right: Handle<Expression<'a>>, + }, + + /// A function call or type constructor expression. + /// + /// We can't tell whether an expression like `IDENTIFIER(EXPR, ...)` is a + /// construction expression or a function call until we know `IDENTIFIER`'s + /// definition, so we represent everything of that form as one of these + /// expressions until lowering. At that point, [`Lowerer::call`] has + /// everything's definition in hand, and can decide whether to emit a Naga + /// [`Constant`], [`As`], [`Splat`], or [`Compose`] expression. + /// + /// [`Lowerer::call`]: Lowerer::call + /// [`Constant`]: crate::Expression::Constant + /// [`As`]: crate::Expression::As + /// [`Splat`]: crate::Expression::Splat + /// [`Compose`]: crate::Expression::Compose + Call { + function: Ident<'a>, + arguments: Vec<Handle<Expression<'a>>>, + }, + Index { + base: Handle<Expression<'a>>, + index: Handle<Expression<'a>>, + }, + Member { + base: Handle<Expression<'a>>, + field: Ident<'a>, + }, + Bitcast { + expr: Handle<Expression<'a>>, + to: Handle<Type<'a>>, + ty_span: Span, + }, +} + +#[derive(Debug)] +pub struct LocalVariable<'a> { + pub name: Ident<'a>, + pub ty: Option<Handle<Type<'a>>>, + pub init: Option<Handle<Expression<'a>>>, + pub handle: Handle<Local>, +} + +#[derive(Debug)] +pub struct Let<'a> { + pub name: Ident<'a>, + pub ty: Option<Handle<Type<'a>>>, + pub init: Handle<Expression<'a>>, + pub handle: Handle<Local>, +} + +#[derive(Debug)] +pub enum LocalDecl<'a> { + Var(LocalVariable<'a>), + Let(Let<'a>), +} + +#[derive(Debug)] +/// A placeholder for a local variable declaration. +/// +/// See [`Function::locals`] for more information. +pub struct Local; diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs new file mode 100644 index 0000000000..08f1e39285 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs @@ -0,0 +1,254 @@ +use super::Error; +use crate::front::wgsl::Scalar; +use crate::Span; + +pub fn map_address_space(word: &str, span: Span) -> Result<crate::AddressSpace, Error<'_>> { + match word { + "private" => Ok(crate::AddressSpace::Private), + "workgroup" => Ok(crate::AddressSpace::WorkGroup), + "uniform" => Ok(crate::AddressSpace::Uniform), + "storage" => Ok(crate::AddressSpace::Storage { + access: crate::StorageAccess::default(), + }), + "push_constant" => Ok(crate::AddressSpace::PushConstant), + "function" => Ok(crate::AddressSpace::Function), + _ => Err(Error::UnknownAddressSpace(span)), + } +} + +pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> { + Ok(match word { + "position" => crate::BuiltIn::Position { invariant: false }, + // vertex + "vertex_index" => crate::BuiltIn::VertexIndex, + "instance_index" => crate::BuiltIn::InstanceIndex, + "view_index" => crate::BuiltIn::ViewIndex, + // fragment + "front_facing" => crate::BuiltIn::FrontFacing, + "frag_depth" => crate::BuiltIn::FragDepth, + "primitive_index" => crate::BuiltIn::PrimitiveIndex, + "sample_index" => crate::BuiltIn::SampleIndex, + "sample_mask" => crate::BuiltIn::SampleMask, + // compute + "global_invocation_id" => crate::BuiltIn::GlobalInvocationId, + "local_invocation_id" => crate::BuiltIn::LocalInvocationId, + "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, + "workgroup_id" => crate::BuiltIn::WorkGroupId, + "num_workgroups" => crate::BuiltIn::NumWorkGroups, + _ => return Err(Error::UnknownBuiltin(span)), + }) +} + +pub fn map_interpolation(word: &str, span: Span) -> Result<crate::Interpolation, Error<'_>> { + match word { + "linear" => Ok(crate::Interpolation::Linear), + "flat" => Ok(crate::Interpolation::Flat), + "perspective" => Ok(crate::Interpolation::Perspective), + _ => Err(Error::UnknownAttribute(span)), + } +} + +pub fn map_sampling(word: &str, span: Span) -> Result<crate::Sampling, Error<'_>> { + match word { + "center" => Ok(crate::Sampling::Center), + "centroid" => Ok(crate::Sampling::Centroid), + "sample" => Ok(crate::Sampling::Sample), + _ => Err(Error::UnknownAttribute(span)), + } +} + +pub fn map_storage_format(word: &str, span: Span) -> Result<crate::StorageFormat, Error<'_>> { + use crate::StorageFormat as Sf; + Ok(match word { + "r8unorm" => Sf::R8Unorm, + "r8snorm" => Sf::R8Snorm, + "r8uint" => Sf::R8Uint, + "r8sint" => Sf::R8Sint, + "r16unorm" => Sf::R16Unorm, + "r16snorm" => Sf::R16Snorm, + "r16uint" => Sf::R16Uint, + "r16sint" => Sf::R16Sint, + "r16float" => Sf::R16Float, + "rg8unorm" => Sf::Rg8Unorm, + "rg8snorm" => Sf::Rg8Snorm, + "rg8uint" => Sf::Rg8Uint, + "rg8sint" => Sf::Rg8Sint, + "r32uint" => Sf::R32Uint, + "r32sint" => Sf::R32Sint, + "r32float" => Sf::R32Float, + "rg16unorm" => Sf::Rg16Unorm, + "rg16snorm" => Sf::Rg16Snorm, + "rg16uint" => Sf::Rg16Uint, + "rg16sint" => Sf::Rg16Sint, + "rg16float" => Sf::Rg16Float, + "rgba8unorm" => Sf::Rgba8Unorm, + "rgba8snorm" => Sf::Rgba8Snorm, + "rgba8uint" => Sf::Rgba8Uint, + "rgba8sint" => Sf::Rgba8Sint, + "rgb10a2uint" => Sf::Rgb10a2Uint, + "rgb10a2unorm" => Sf::Rgb10a2Unorm, + "rg11b10float" => Sf::Rg11b10Float, + "rg32uint" => Sf::Rg32Uint, + "rg32sint" => Sf::Rg32Sint, + "rg32float" => Sf::Rg32Float, + "rgba16unorm" => Sf::Rgba16Unorm, + "rgba16snorm" => Sf::Rgba16Snorm, + "rgba16uint" => Sf::Rgba16Uint, + "rgba16sint" => Sf::Rgba16Sint, + "rgba16float" => Sf::Rgba16Float, + "rgba32uint" => Sf::Rgba32Uint, + "rgba32sint" => Sf::Rgba32Sint, + "rgba32float" => Sf::Rgba32Float, + "bgra8unorm" => Sf::Bgra8Unorm, + _ => return Err(Error::UnknownStorageFormat(span)), + }) +} + +pub fn get_scalar_type(word: &str) -> Option<Scalar> { + use crate::ScalarKind as Sk; + match word { + // "f16" => Some(Scalar { kind: Sk::Float, width: 2 }), + "f32" => Some(Scalar { + kind: Sk::Float, + width: 4, + }), + "f64" => Some(Scalar { + kind: Sk::Float, + width: 8, + }), + "i32" => Some(Scalar { + kind: Sk::Sint, + width: 4, + }), + "u32" => Some(Scalar { + kind: Sk::Uint, + width: 4, + }), + "bool" => Some(Scalar { + kind: Sk::Bool, + width: crate::BOOL_WIDTH, + }), + _ => None, + } +} + +pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + match word { + "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)), + "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)), + "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)), + "dpdxFine" => Some((Axis::X, Ctrl::Fine)), + "dpdyFine" => Some((Axis::Y, Ctrl::Fine)), + "fwidthFine" => Some((Axis::Width, Ctrl::Fine)), + "dpdx" => Some((Axis::X, Ctrl::None)), + "dpdy" => Some((Axis::Y, Ctrl::None)), + "fwidth" => Some((Axis::Width, Ctrl::None)), + _ => None, + } +} + +pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> { + match word { + "any" => Some(crate::RelationalFunction::Any), + "all" => Some(crate::RelationalFunction::All), + _ => None, + } +} + +pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> { + use crate::MathFunction as Mf; + Some(match word { + // comparison + "abs" => Mf::Abs, + "min" => Mf::Min, + "max" => Mf::Max, + "clamp" => Mf::Clamp, + "saturate" => Mf::Saturate, + // trigonometry + "cos" => Mf::Cos, + "cosh" => Mf::Cosh, + "sin" => Mf::Sin, + "sinh" => Mf::Sinh, + "tan" => Mf::Tan, + "tanh" => Mf::Tanh, + "acos" => Mf::Acos, + "acosh" => Mf::Acosh, + "asin" => Mf::Asin, + "asinh" => Mf::Asinh, + "atan" => Mf::Atan, + "atanh" => Mf::Atanh, + "atan2" => Mf::Atan2, + "radians" => Mf::Radians, + "degrees" => Mf::Degrees, + // decomposition + "ceil" => Mf::Ceil, + "floor" => Mf::Floor, + "round" => Mf::Round, + "fract" => Mf::Fract, + "trunc" => Mf::Trunc, + "modf" => Mf::Modf, + "frexp" => Mf::Frexp, + "ldexp" => Mf::Ldexp, + // exponent + "exp" => Mf::Exp, + "exp2" => Mf::Exp2, + "log" => Mf::Log, + "log2" => Mf::Log2, + "pow" => Mf::Pow, + // geometry + "dot" => Mf::Dot, + "cross" => Mf::Cross, + "distance" => Mf::Distance, + "length" => Mf::Length, + "normalize" => Mf::Normalize, + "faceForward" => Mf::FaceForward, + "reflect" => Mf::Reflect, + "refract" => Mf::Refract, + // computational + "sign" => Mf::Sign, + "fma" => Mf::Fma, + "mix" => Mf::Mix, + "step" => Mf::Step, + "smoothstep" => Mf::SmoothStep, + "sqrt" => Mf::Sqrt, + "inverseSqrt" => Mf::InverseSqrt, + "transpose" => Mf::Transpose, + "determinant" => Mf::Determinant, + // bits + "countTrailingZeros" => Mf::CountTrailingZeros, + "countLeadingZeros" => Mf::CountLeadingZeros, + "countOneBits" => Mf::CountOneBits, + "reverseBits" => Mf::ReverseBits, + "extractBits" => Mf::ExtractBits, + "insertBits" => Mf::InsertBits, + "firstTrailingBit" => Mf::FindLsb, + "firstLeadingBit" => Mf::FindMsb, + // data packing + "pack4x8snorm" => Mf::Pack4x8snorm, + "pack4x8unorm" => Mf::Pack4x8unorm, + "pack2x16snorm" => Mf::Pack2x16snorm, + "pack2x16unorm" => Mf::Pack2x16unorm, + "pack2x16float" => Mf::Pack2x16float, + // data unpacking + "unpack4x8snorm" => Mf::Unpack4x8snorm, + "unpack4x8unorm" => Mf::Unpack4x8unorm, + "unpack2x16snorm" => Mf::Unpack2x16snorm, + "unpack2x16unorm" => Mf::Unpack2x16unorm, + "unpack2x16float" => Mf::Unpack2x16float, + _ => return None, + }) +} + +pub fn map_conservative_depth( + word: &str, + span: Span, +) -> Result<crate::ConservativeDepth, Error<'_>> { + use crate::ConservativeDepth as Cd; + match word { + "greater_equal" => Ok(Cd::GreaterEqual), + "less_equal" => Ok(Cd::LessEqual), + "unchanged" => Ok(Cd::Unchanged), + _ => Err(Error::UnknownConservativeDepth(span)), + } +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/lexer.rs b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs new file mode 100644 index 0000000000..d03a448561 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs @@ -0,0 +1,739 @@ +use super::{number::consume_number, Error, ExpectedToken}; +use crate::front::wgsl::error::NumberError; +use crate::front::wgsl::parse::{conv, Number}; +use crate::front::wgsl::Scalar; +use crate::Span; + +type TokenSpan<'a> = (Token<'a>, Span); + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Token<'a> { + Separator(char), + Paren(char), + Attribute, + Number(Result<Number, NumberError>), + Word(&'a str), + Operation(char), + LogicalOperation(char), + ShiftOperation(char), + AssignmentOperation(char), + IncrementOperation, + DecrementOperation, + Arrow, + Unknown(char), + Trivia, + End, +} + +fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) { + let pos = input.find(|c| !what(c)).unwrap_or(input.len()); + input.split_at(pos) +} + +/// Return the token at the start of `input`. +/// +/// If `generic` is `false`, then the bit shift operators `>>` or `<<` +/// are valid lookahead tokens for the current parser state (see [§3.1 +/// Parsing] in the WGSL specification). In other words: +/// +/// - If `generic` is `true`, then we are expecting an angle bracket +/// around a generic type parameter, like the `<` and `>` in +/// `vec3<f32>`, so interpret `<` and `>` as `Token::Paren` tokens, +/// even if they're part of `<<` or `>>` sequences. +/// +/// - Otherwise, interpret `<<` and `>>` as shift operators: +/// `Token::LogicalOperation` tokens. +/// +/// [§3.1 Parsing]: https://gpuweb.github.io/gpuweb/wgsl/#parsing +fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) { + let mut chars = input.chars(); + let cur = match chars.next() { + Some(c) => c, + None => return (Token::End, ""), + }; + match cur { + ':' | ';' | ',' => (Token::Separator(cur), chars.as_str()), + '.' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('0'..='9') => consume_number(input), + _ => (Token::Separator(cur), og_chars), + } + } + '@' => (Token::Attribute, chars.as_str()), + '(' | ')' | '{' | '}' | '[' | ']' => (Token::Paren(cur), chars.as_str()), + '<' | '>' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') if !generic => (Token::LogicalOperation(cur), chars.as_str()), + Some(c) if c == cur && !generic => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::ShiftOperation(cur), og_chars), + } + } + _ => (Token::Paren(cur), og_chars), + } + } + '0'..='9' => consume_number(input), + '/' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('/') => { + let _ = chars.position(is_comment_end); + (Token::Trivia, chars.as_str()) + } + Some('*') => { + let mut depth = 1; + let mut prev = None; + + for c in &mut chars { + match (prev, c) { + (Some('*'), '/') => { + prev = None; + depth -= 1; + if depth == 0 { + return (Token::Trivia, chars.as_str()); + } + } + (Some('/'), '*') => { + prev = None; + depth += 1; + } + _ => { + prev = Some(c); + } + } + } + + (Token::End, "") + } + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '-' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('>') => (Token::Arrow, chars.as_str()), + Some('-') => (Token::DecrementOperation, chars.as_str()), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '+' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('+') => (Token::IncrementOperation, chars.as_str()), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '*' | '%' | '^' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '~' => (Token::Operation(cur), chars.as_str()), + '=' | '!' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') => (Token::LogicalOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '&' | '|' => { + let og_chars = chars.as_str(); + match chars.next() { + Some(c) if c == cur => (Token::LogicalOperation(cur), chars.as_str()), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + _ if is_blankspace(cur) => { + let (_, rest) = consume_any(input, is_blankspace); + (Token::Trivia, rest) + } + _ if is_word_start(cur) => { + let (word, rest) = consume_any(input, is_word_part); + (Token::Word(word), rest) + } + _ => (Token::Unknown(cur), chars.as_str()), + } +} + +/// Returns whether or not a char is a comment end +/// (Unicode Pattern_White_Space excluding U+0020, U+0009, U+200E and U+200F) +const fn is_comment_end(c: char) -> bool { + match c { + '\u{000a}'..='\u{000d}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true, + _ => false, + } +} + +/// Returns whether or not a char is a blankspace (Unicode Pattern_White_Space) +const fn is_blankspace(c: char) -> bool { + match c { + '\u{0020}' + | '\u{0009}'..='\u{000d}' + | '\u{0085}' + | '\u{200e}' + | '\u{200f}' + | '\u{2028}' + | '\u{2029}' => true, + _ => false, + } +} + +/// Returns whether or not a char is a word start (Unicode XID_Start + '_') +fn is_word_start(c: char) -> bool { + c == '_' || unicode_xid::UnicodeXID::is_xid_start(c) +} + +/// Returns whether or not a char is a word part (Unicode XID_Continue) +fn is_word_part(c: char) -> bool { + unicode_xid::UnicodeXID::is_xid_continue(c) +} + +#[derive(Clone)] +pub(in crate::front::wgsl) struct Lexer<'a> { + input: &'a str, + pub(in crate::front::wgsl) source: &'a str, + // The byte offset of the end of the last non-trivia token. + last_end_offset: usize, +} + +impl<'a> Lexer<'a> { + pub(in crate::front::wgsl) const fn new(input: &'a str) -> Self { + Lexer { + input, + source: input, + last_end_offset: 0, + } + } + + /// Calls the function with a lexer and returns the result of the function as well as the span for everything the function parsed + /// + /// # Examples + /// ```ignore + /// let lexer = Lexer::new("5"); + /// let (value, span) = lexer.capture_span(Lexer::next_uint_literal); + /// assert_eq!(value, 5); + /// ``` + #[inline] + pub fn capture_span<T, E>( + &mut self, + inner: impl FnOnce(&mut Self) -> Result<T, E>, + ) -> Result<(T, Span), E> { + let start = self.current_byte_offset(); + let res = inner(self)?; + let end = self.current_byte_offset(); + Ok((res, Span::from(start..end))) + } + + pub(in crate::front::wgsl) fn start_byte_offset(&mut self) -> usize { + loop { + // Eat all trivia because `next` doesn't eat trailing trivia. + let (token, rest) = consume_token(self.input, false); + if let Token::Trivia = token { + self.input = rest; + } else { + return self.current_byte_offset(); + } + } + } + + fn peek_token_and_rest(&mut self) -> (TokenSpan<'a>, &'a str) { + let mut cloned = self.clone(); + let token = cloned.next(); + let rest = cloned.input; + (token, rest) + } + + const fn current_byte_offset(&self) -> usize { + self.source.len() - self.input.len() + } + + pub(in crate::front::wgsl) fn span_from(&self, offset: usize) -> Span { + Span::from(offset..self.last_end_offset) + } + + /// Return the next non-whitespace token from `self`. + /// + /// Assume we are a parse state where bit shift operators may + /// occur, but not angle brackets. + #[must_use] + pub(in crate::front::wgsl) fn next(&mut self) -> TokenSpan<'a> { + self.next_impl(false) + } + + /// Return the next non-whitespace token from `self`. + /// + /// Assume we are in a parse state where angle brackets may occur, + /// but not bit shift operators. + #[must_use] + pub(in crate::front::wgsl) fn next_generic(&mut self) -> TokenSpan<'a> { + self.next_impl(true) + } + + /// Return the next non-whitespace token from `self`, with a span. + /// + /// See [`consume_token`] for the meaning of `generic`. + fn next_impl(&mut self, generic: bool) -> TokenSpan<'a> { + let mut start_byte_offset = self.current_byte_offset(); + loop { + let (token, rest) = consume_token(self.input, generic); + self.input = rest; + match token { + Token::Trivia => start_byte_offset = self.current_byte_offset(), + _ => { + self.last_end_offset = self.current_byte_offset(); + return (token, self.span_from(start_byte_offset)); + } + } + } + } + + #[must_use] + pub(in crate::front::wgsl) fn peek(&mut self) -> TokenSpan<'a> { + let (token, _) = self.peek_token_and_rest(); + token + } + + pub(in crate::front::wgsl) fn expect_span( + &mut self, + expected: Token<'a>, + ) -> Result<Span, Error<'a>> { + let next = self.next(); + if next.0 == expected { + Ok(next.1) + } else { + Err(Error::Unexpected(next.1, ExpectedToken::Token(expected))) + } + } + + pub(in crate::front::wgsl) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> { + self.expect_span(expected)?; + Ok(()) + } + + pub(in crate::front::wgsl) fn expect_generic_paren( + &mut self, + expected: char, + ) -> Result<(), Error<'a>> { + let next = self.next_generic(); + if next.0 == Token::Paren(expected) { + Ok(()) + } else { + Err(Error::Unexpected( + next.1, + ExpectedToken::Token(Token::Paren(expected)), + )) + } + } + + /// If the next token matches it is skipped and true is returned + pub(in crate::front::wgsl) fn skip(&mut self, what: Token<'_>) -> bool { + let (peeked_token, rest) = self.peek_token_and_rest(); + if peeked_token.0 == what { + self.input = rest; + true + } else { + false + } + } + + pub(in crate::front::wgsl) fn next_ident_with_span( + &mut self, + ) -> Result<(&'a str, Span), Error<'a>> { + match self.next() { + (Token::Word("_"), span) => Err(Error::InvalidIdentifierUnderscore(span)), + (Token::Word(word), span) if word.starts_with("__") => { + Err(Error::ReservedIdentifierPrefix(span)) + } + (Token::Word(word), span) => Ok((word, span)), + other => Err(Error::Unexpected(other.1, ExpectedToken::Identifier)), + } + } + + pub(in crate::front::wgsl) fn next_ident( + &mut self, + ) -> Result<super::ast::Ident<'a>, Error<'a>> { + let ident = self + .next_ident_with_span() + .map(|(name, span)| super::ast::Ident { name, span })?; + + if crate::keywords::wgsl::RESERVED.contains(&ident.name) { + return Err(Error::ReservedKeyword(ident.span)); + } + + Ok(ident) + } + + /// Parses a generic scalar type, for example `<f32>`. + pub(in crate::front::wgsl) fn next_scalar_generic(&mut self) -> Result<Scalar, Error<'a>> { + self.expect_generic_paren('<')?; + let pair = match self.next() { + (Token::Word(word), span) => { + conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span)) + } + (_, span) => Err(Error::UnknownScalarType(span)), + }?; + self.expect_generic_paren('>')?; + Ok(pair) + } + + /// Parses a generic scalar type, for example `<f32>`. + /// + /// Returns the span covering the inner type, excluding the brackets. + pub(in crate::front::wgsl) fn next_scalar_generic_with_span( + &mut self, + ) -> Result<(Scalar, Span), Error<'a>> { + self.expect_generic_paren('<')?; + let pair = match self.next() { + (Token::Word(word), span) => conv::get_scalar_type(word) + .map(|scalar| (scalar, span)) + .ok_or(Error::UnknownScalarType(span)), + (_, span) => Err(Error::UnknownScalarType(span)), + }?; + self.expect_generic_paren('>')?; + Ok(pair) + } + + pub(in crate::front::wgsl) fn next_storage_access( + &mut self, + ) -> Result<crate::StorageAccess, Error<'a>> { + let (ident, span) = self.next_ident_with_span()?; + match ident { + "read" => Ok(crate::StorageAccess::LOAD), + "write" => Ok(crate::StorageAccess::STORE), + "read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE), + _ => Err(Error::UnknownAccess(span)), + } + } + + pub(in crate::front::wgsl) fn next_format_generic( + &mut self, + ) -> Result<(crate::StorageFormat, crate::StorageAccess), Error<'a>> { + self.expect(Token::Paren('<'))?; + let (ident, ident_span) = self.next_ident_with_span()?; + let format = conv::map_storage_format(ident, ident_span)?; + self.expect(Token::Separator(','))?; + let access = self.next_storage_access()?; + self.expect(Token::Paren('>'))?; + Ok((format, access)) + } + + pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<(), Error<'a>> { + self.expect(Token::Paren('(')) + } + + pub(in crate::front::wgsl) fn close_arguments(&mut self) -> Result<(), Error<'a>> { + let _ = self.skip(Token::Separator(',')); + self.expect(Token::Paren(')')) + } + + pub(in crate::front::wgsl) fn next_argument(&mut self) -> Result<bool, Error<'a>> { + let paren = Token::Paren(')'); + if self.skip(Token::Separator(',')) { + Ok(!self.skip(paren)) + } else { + self.expect(paren).map(|()| false) + } + } +} + +#[cfg(test)] +#[track_caller] +fn sub_test(source: &str, expected_tokens: &[Token]) { + let mut lex = Lexer::new(source); + for &token in expected_tokens { + assert_eq!(lex.next().0, token); + } + assert_eq!(lex.next().0, Token::End); +} + +#[test] +fn test_numbers() { + // WGSL spec examples // + + // decimal integer + sub_test( + "0x123 0X123u 1u 123 0 0i 0x3f", + &[ + Token::Number(Ok(Number::AbstractInt(291))), + Token::Number(Ok(Number::U32(291))), + Token::Number(Ok(Number::U32(1))), + Token::Number(Ok(Number::AbstractInt(123))), + Token::Number(Ok(Number::AbstractInt(0))), + Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::AbstractInt(63))), + ], + ); + // decimal floating point + sub_test( + "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h", + &[ + Token::Number(Ok(Number::F32(0.))), + Token::Number(Ok(Number::AbstractFloat(1.))), + Token::Number(Ok(Number::AbstractFloat(0.01))), + Token::Number(Ok(Number::AbstractFloat(12.34))), + Token::Number(Ok(Number::F32(0.))), + Token::Number(Err(NumberError::UnimplementedF16)), + Token::Number(Ok(Number::AbstractFloat(0.001))), + Token::Number(Ok(Number::AbstractFloat(43.75))), + Token::Number(Ok(Number::F32(16.))), + Token::Number(Ok(Number::AbstractFloat(0.1875))), + Token::Number(Err(NumberError::UnimplementedF16)), + Token::Number(Ok(Number::AbstractFloat(0.12109375))), + Token::Number(Err(NumberError::UnimplementedF16)), + ], + ); + + // MIN / MAX // + + // min / max decimal integer + sub_test( + "0i 2147483647i 2147483648i", + &[ + Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::I32(i32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + // min / max decimal unsigned integer + sub_test( + "0u 4294967295u 4294967296u", + &[ + Token::Number(Ok(Number::U32(u32::MIN))), + Token::Number(Ok(Number::U32(u32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + // min / max hexadecimal signed integer + sub_test( + "0x0i 0x7FFFFFFFi 0x80000000i", + &[ + Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::I32(i32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + // min / max hexadecimal unsigned integer + sub_test( + "0x0u 0xFFFFFFFFu 0x100000000u", + &[ + Token::Number(Ok(Number::U32(u32::MIN))), + Token::Number(Ok(Number::U32(u32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + // min/max decimal abstract int + sub_test( + "0 9223372036854775807 9223372036854775808", + &[ + Token::Number(Ok(Number::AbstractInt(0))), + Token::Number(Ok(Number::AbstractInt(i64::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + // min/max hexadecimal abstract int + sub_test( + "0 0x7fffffffffffffff 0x8000000000000000", + &[ + Token::Number(Ok(Number::AbstractInt(0))), + Token::Number(Ok(Number::AbstractInt(i64::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + /// ≈ 2^-126 * 2^−23 (= 2^−149) + const SMALLEST_POSITIVE_SUBNORMAL_F32: f32 = 1e-45; + /// ≈ 2^-126 * (1 − 2^−23) + const LARGEST_SUBNORMAL_F32: f32 = 1.1754942e-38; + /// ≈ 2^-126 + const SMALLEST_POSITIVE_NORMAL_F32: f32 = f32::MIN_POSITIVE; + /// ≈ 1 − 2^−24 + const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994; + /// ≈ 1 + 2^−23 + const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001; + /// ≈ 2^127 * (2 − 2^−23) + const LARGEST_NORMAL_F32: f32 = f32::MAX; + + // decimal floating point + sub_test( + "1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f 3.40282347e+38f", + &[ + Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))), + Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))), + Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))), + Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))), + Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))), + Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))), + ], + ); + sub_test( + "3.40282367e+38f", + &[ + Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128 + ], + ); + + // hexadecimal floating point + sub_test( + "0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f 0xFFFFFFp+104f", + &[ + Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))), + Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))), + Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))), + Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))), + Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))), + Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))), + ], + ); + sub_test( + "0x1p128f 0x1.000001p0f", + &[ + Token::Number(Err(NumberError::NotRepresentable)), // = 2^128 + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); +} + +#[test] +fn double_floats() { + sub_test( + "0x1.2p4lf 0x1p8lf 0.0625lf 625e-4lf 10lf 10l", + &[ + Token::Number(Ok(Number::F64(18.0))), + Token::Number(Ok(Number::F64(256.0))), + Token::Number(Ok(Number::F64(0.0625))), + Token::Number(Ok(Number::F64(0.0625))), + Token::Number(Ok(Number::F64(10.0))), + Token::Number(Ok(Number::AbstractInt(10))), + Token::Word("l"), + ], + ) +} + +#[test] +fn test_tokens() { + sub_test("id123_OK", &[Token::Word("id123_OK")]); + sub_test( + "92No", + &[ + Token::Number(Ok(Number::AbstractInt(92))), + Token::Word("No"), + ], + ); + sub_test( + "2u3o", + &[ + Token::Number(Ok(Number::U32(2))), + Token::Number(Ok(Number::AbstractInt(3))), + Token::Word("o"), + ], + ); + sub_test( + "2.4f44po", + &[ + Token::Number(Ok(Number::F32(2.4))), + Token::Number(Ok(Number::AbstractInt(44))), + Token::Word("po"), + ], + ); + sub_test( + "Δέλτα réflexion Кызыл 𐰓𐰏𐰇 朝焼け سلام 검정 שָׁלוֹם गुलाबी փիրուզ", + &[ + Token::Word("Δέλτα"), + Token::Word("réflexion"), + Token::Word("Кызыл"), + Token::Word("𐰓𐰏𐰇"), + Token::Word("朝焼け"), + Token::Word("سلام"), + Token::Word("검정"), + Token::Word("שָׁלוֹם"), + Token::Word("गुलाबी"), + Token::Word("փիրուզ"), + ], + ); + sub_test("æNoø", &[Token::Word("æNoø")]); + sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]); + sub_test("No好", &[Token::Word("No好")]); + sub_test("_No", &[Token::Word("_No")]); + sub_test( + "*/*/***/*//=/*****//", + &[ + Token::Operation('*'), + Token::AssignmentOperation('/'), + Token::Operation('/'), + ], + ); + + // Type suffixes are only allowed on hex float literals + // if you provided an exponent. + sub_test( + "0x1.2f 0x1.2f 0x1.2h 0x1.2H 0x1.2lf", + &[ + // The 'f' suffixes are taken as a hex digit: + // the fractional part is 0x2f / 256. + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.125))), + Token::Word("h"), + Token::Number(Ok(Number::AbstractFloat(1.125))), + Token::Word("H"), + Token::Number(Ok(Number::AbstractFloat(1.125))), + Token::Word("lf"), + ], + ) +} + +#[test] +fn test_variable_decl() { + sub_test( + "@group(0 ) var< uniform> texture: texture_multisampled_2d <f32 >;", + &[ + Token::Attribute, + Token::Word("group"), + Token::Paren('('), + Token::Number(Ok(Number::AbstractInt(0))), + Token::Paren(')'), + Token::Word("var"), + Token::Paren('<'), + Token::Word("uniform"), + Token::Paren('>'), + Token::Word("texture"), + Token::Separator(':'), + Token::Word("texture_multisampled_2d"), + Token::Paren('<'), + Token::Word("f32"), + Token::Paren('>'), + Token::Separator(';'), + ], + ); + sub_test( + "var<storage,read_write> buffer: array<u32>;", + &[ + Token::Word("var"), + Token::Paren('<'), + Token::Word("storage"), + Token::Separator(','), + Token::Word("read_write"), + Token::Paren('>'), + Token::Word("buffer"), + Token::Separator(':'), + Token::Word("array"), + Token::Paren('<'), + Token::Word("u32"), + Token::Paren('>'), + Token::Separator(';'), + ], + ); +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs new file mode 100644 index 0000000000..51fc2f013b --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs @@ -0,0 +1,2350 @@ +use crate::front::wgsl::error::{Error, ExpectedToken}; +use crate::front::wgsl::parse::lexer::{Lexer, Token}; +use crate::front::wgsl::parse::number::Number; +use crate::front::wgsl::Scalar; +use crate::front::SymbolTable; +use crate::{Arena, FastIndexSet, Handle, ShaderStage, Span}; + +pub mod ast; +pub mod conv; +pub mod lexer; +pub mod number; + +/// State for constructing an AST expression. +/// +/// Not to be confused with [`lower::ExpressionContext`], which is for producing +/// Naga IR from the AST we produce here. +/// +/// [`lower::ExpressionContext`]: super::lower::ExpressionContext +struct ExpressionContext<'input, 'temp, 'out> { + /// The [`TranslationUnit::expressions`] arena to which we should contribute + /// expressions. + /// + /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions + expressions: &'out mut Arena<ast::Expression<'input>>, + + /// The [`TranslationUnit::types`] arena to which we should contribute new + /// types. + /// + /// [`TranslationUnit::types`]: ast::TranslationUnit::types + types: &'out mut Arena<ast::Type<'input>>, + + /// A map from identifiers in scope to the locals/arguments they represent. + /// + /// The handles refer to the [`Function::locals`] area; see that field's + /// documentation for details. + /// + /// [`Function::locals`]: ast::Function::locals + local_table: &'temp mut SymbolTable<&'input str, Handle<ast::Local>>, + + /// The [`Function::locals`] arena for the function we're building. + /// + /// [`Function::locals`]: ast::Function::locals + locals: &'out mut Arena<ast::Local>, + + /// Identifiers used by the current global declaration that have no local definition. + /// + /// This becomes the [`GlobalDecl`]'s [`dependencies`] set. + /// + /// Note that we don't know at parse time what kind of [`GlobalDecl`] the + /// name refers to. We can't look up names until we've seen the entire + /// translation unit. + /// + /// [`GlobalDecl`]: ast::GlobalDecl + /// [`dependencies`]: ast::GlobalDecl::dependencies + unresolved: &'out mut FastIndexSet<ast::Dependency<'input>>, +} + +impl<'a> ExpressionContext<'a, '_, '_> { + fn parse_binary_op( + &mut self, + lexer: &mut Lexer<'a>, + classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>, + mut parser: impl FnMut( + &mut Lexer<'a>, + &mut Self, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + let start = lexer.start_byte_offset(); + let mut accumulator = parser(lexer, self)?; + while let Some(op) = classifier(lexer.peek().0) { + let _ = lexer.next(); + let left = accumulator; + let right = parser(lexer, self)?; + accumulator = self.expressions.append( + ast::Expression::Binary { op, left, right }, + lexer.span_from(start), + ); + } + Ok(accumulator) + } + + fn declare_local(&mut self, name: ast::Ident<'a>) -> Result<Handle<ast::Local>, Error<'a>> { + let handle = self.locals.append(ast::Local, name.span); + if let Some(old) = self.local_table.add(name.name, handle) { + Err(Error::Redefinition { + previous: self.locals.get_span(old), + current: name.span, + }) + } else { + Ok(handle) + } + } +} + +/// Which grammar rule we are in the midst of parsing. +/// +/// This is used for error checking. `Parser` maintains a stack of +/// these and (occasionally) checks that it is being pushed and popped +/// as expected. +#[derive(Clone, Debug, PartialEq)] +enum Rule { + Attribute, + VariableDecl, + TypeDecl, + FunctionDecl, + Block, + Statement, + PrimaryExpr, + SingularExpr, + UnaryExpr, + GeneralExpr, +} + +struct ParsedAttribute<T> { + value: Option<T>, +} + +impl<T> Default for ParsedAttribute<T> { + fn default() -> Self { + Self { value: None } + } +} + +impl<T> ParsedAttribute<T> { + fn set(&mut self, value: T, name_span: Span) -> Result<(), Error<'static>> { + if self.value.is_some() { + return Err(Error::RepeatedAttribute(name_span)); + } + self.value = Some(value); + Ok(()) + } +} + +#[derive(Default)] +struct BindingParser<'a> { + location: ParsedAttribute<Handle<ast::Expression<'a>>>, + second_blend_source: ParsedAttribute<bool>, + built_in: ParsedAttribute<crate::BuiltIn>, + interpolation: ParsedAttribute<crate::Interpolation>, + sampling: ParsedAttribute<crate::Sampling>, + invariant: ParsedAttribute<bool>, +} + +impl<'a> BindingParser<'a> { + fn parse( + &mut self, + parser: &mut Parser, + lexer: &mut Lexer<'a>, + name: &'a str, + name_span: Span, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<(), Error<'a>> { + match name { + "location" => { + lexer.expect(Token::Paren('('))?; + self.location + .set(parser.general_expression(lexer, ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + "builtin" => { + lexer.expect(Token::Paren('('))?; + let (raw, span) = lexer.next_ident_with_span()?; + self.built_in + .set(conv::map_built_in(raw, span)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + "interpolate" => { + lexer.expect(Token::Paren('('))?; + let (raw, span) = lexer.next_ident_with_span()?; + self.interpolation + .set(conv::map_interpolation(raw, span)?, name_span)?; + if lexer.skip(Token::Separator(',')) { + let (raw, span) = lexer.next_ident_with_span()?; + self.sampling + .set(conv::map_sampling(raw, span)?, name_span)?; + } + lexer.expect(Token::Paren(')'))?; + } + "second_blend_source" => { + self.second_blend_source.set(true, name_span)?; + } + "invariant" => { + self.invariant.set(true, name_span)?; + } + _ => return Err(Error::UnknownAttribute(name_span)), + } + Ok(()) + } + + fn finish(self, span: Span) -> Result<Option<ast::Binding<'a>>, Error<'a>> { + match ( + self.location.value, + self.built_in.value, + self.interpolation.value, + self.sampling.value, + self.invariant.value.unwrap_or_default(), + ) { + (None, None, None, None, false) => Ok(None), + (Some(location), None, interpolation, sampling, false) => { + // Before handing over the completed `Module`, we call + // `apply_default_interpolation` to ensure that the interpolation and + // sampling have been explicitly specified on all vertex shader output and fragment + // shader input user bindings, so leaving them potentially `None` here is fine. + Ok(Some(ast::Binding::Location { + location, + interpolation, + sampling, + second_blend_source: self.second_blend_source.value.unwrap_or(false), + })) + } + (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant) => { + Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position { + invariant, + }))) + } + (None, Some(built_in), None, None, false) => Ok(Some(ast::Binding::BuiltIn(built_in))), + (_, _, _, _, _) => Err(Error::InconsistentBinding(span)), + } + } +} + +pub struct Parser { + rules: Vec<(Rule, usize)>, +} + +impl Parser { + pub const fn new() -> Self { + Parser { rules: Vec::new() } + } + + fn reset(&mut self) { + self.rules.clear(); + } + + fn push_rule_span(&mut self, rule: Rule, lexer: &mut Lexer<'_>) { + self.rules.push((rule, lexer.start_byte_offset())); + } + + fn pop_rule_span(&mut self, lexer: &Lexer<'_>) -> Span { + let (_, initial) = self.rules.pop().unwrap(); + lexer.span_from(initial) + } + + fn peek_rule_span(&mut self, lexer: &Lexer<'_>) -> Span { + let &(_, initial) = self.rules.last().unwrap(); + lexer.span_from(initial) + } + + fn switch_value<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<ast::SwitchValue<'a>, Error<'a>> { + if let Token::Word("default") = lexer.peek().0 { + let _ = lexer.next(); + return Ok(ast::SwitchValue::Default); + } + + let expr = self.general_expression(lexer, ctx)?; + Ok(ast::SwitchValue::Expr(expr)) + } + + /// Decide if we're looking at a construction expression, and return its + /// type if so. + /// + /// If the identifier `word` is a [type-defining keyword], then return a + /// [`ConstructorType`] value describing the type to build. Return an error + /// if the type is not constructible (like `sampler`). + /// + /// If `word` isn't a type name, then return `None`. + /// + /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords + /// [`ConstructorType`]: ast::ConstructorType + fn constructor_type<'a>( + &mut self, + lexer: &mut Lexer<'a>, + word: &'a str, + span: Span, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Option<ast::ConstructorType<'a>>, Error<'a>> { + if let Some(scalar) = conv::get_scalar_type(word) { + return Ok(Some(ast::ConstructorType::Scalar(scalar))); + } + + let partial = match word { + "vec2" => ast::ConstructorType::PartialVector { + size: crate::VectorSize::Bi, + }, + "vec2i" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + })) + } + "vec2u" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + })) + } + "vec2f" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar::F32, + })) + } + "vec3" => ast::ConstructorType::PartialVector { + size: crate::VectorSize::Tri, + }, + "vec3i" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar::I32, + })) + } + "vec3u" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar::U32, + })) + } + "vec3f" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar::F32, + })) + } + "vec4" => ast::ConstructorType::PartialVector { + size: crate::VectorSize::Quad, + }, + "vec4i" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::I32, + })) + } + "vec4u" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::U32, + })) + } + "vec4f" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::F32, + })) + } + "mat2x2" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Bi, + }, + "mat2x2f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Bi, + width: 4, + })) + } + "mat2x3" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Tri, + }, + "mat2x3f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Tri, + width: 4, + })) + } + "mat2x4" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Quad, + }, + "mat2x4f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Quad, + width: 4, + })) + } + "mat3x2" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Bi, + }, + "mat3x2f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Bi, + width: 4, + })) + } + "mat3x3" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Tri, + }, + "mat3x3f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Tri, + width: 4, + })) + } + "mat3x4" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Quad, + }, + "mat3x4f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Quad, + width: 4, + })) + } + "mat4x2" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Bi, + }, + "mat4x2f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Bi, + width: 4, + })) + } + "mat4x3" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + }, + "mat4x3f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width: 4, + })) + } + "mat4x4" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Quad, + }, + "mat4x4f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Quad, + width: 4, + })) + } + "array" => ast::ConstructorType::PartialArray, + "atomic" + | "binding_array" + | "sampler" + | "sampler_comparison" + | "texture_1d" + | "texture_1d_array" + | "texture_2d" + | "texture_2d_array" + | "texture_3d" + | "texture_cube" + | "texture_cube_array" + | "texture_multisampled_2d" + | "texture_multisampled_2d_array" + | "texture_depth_2d" + | "texture_depth_2d_array" + | "texture_depth_cube" + | "texture_depth_cube_array" + | "texture_depth_multisampled_2d" + | "texture_storage_1d" + | "texture_storage_1d_array" + | "texture_storage_2d" + | "texture_storage_2d_array" + | "texture_storage_3d" => return Err(Error::TypeNotConstructible(span)), + _ => return Ok(None), + }; + + // parse component type if present + match (lexer.peek().0, partial) { + (Token::Paren('<'), ast::ConstructorType::PartialVector { size }) => { + let scalar = lexer.next_scalar_generic()?; + Ok(Some(ast::ConstructorType::Vector { size, scalar })) + } + (Token::Paren('<'), ast::ConstructorType::PartialMatrix { columns, rows }) => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + match scalar.kind { + crate::ScalarKind::Float => Ok(Some(ast::ConstructorType::Matrix { + columns, + rows, + width: scalar.width, + })), + _ => Err(Error::BadMatrixScalarKind(span, scalar)), + } + } + (Token::Paren('<'), ast::ConstructorType::PartialArray) => { + lexer.expect_generic_paren('<')?; + let base = self.type_decl(lexer, ctx)?; + let size = if lexer.skip(Token::Separator(',')) { + let expr = self.unary_expression(lexer, ctx)?; + ast::ArraySize::Constant(expr) + } else { + ast::ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + Ok(Some(ast::ConstructorType::Array { base, size })) + } + (_, partial) => Ok(Some(partial)), + } + } + + /// Expects `name` to be consumed (not in lexer). + fn arguments<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Vec<Handle<ast::Expression<'a>>>, Error<'a>> { + lexer.open_arguments()?; + let mut arguments = Vec::new(); + loop { + if !arguments.is_empty() { + if !lexer.next_argument()? { + break; + } + } else if lexer.skip(Token::Paren(')')) { + break; + } + let arg = self.general_expression(lexer, ctx)?; + arguments.push(arg); + } + + Ok(arguments) + } + + /// Expects [`Rule::PrimaryExpr`] or [`Rule::SingularExpr`] on top; does not pop it. + /// Expects `name` to be consumed (not in lexer). + fn function_call<'a>( + &mut self, + lexer: &mut Lexer<'a>, + name: &'a str, + name_span: Span, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + assert!(self.rules.last().is_some()); + + let expr = match name { + // bitcast looks like a function call, but it's an operator and must be handled differently. + "bitcast" => { + lexer.expect_generic_paren('<')?; + let start = lexer.start_byte_offset(); + let to = self.type_decl(lexer, ctx)?; + let span = lexer.span_from(start); + lexer.expect_generic_paren('>')?; + + lexer.open_arguments()?; + let expr = self.general_expression(lexer, ctx)?; + lexer.close_arguments()?; + + ast::Expression::Bitcast { + expr, + to, + ty_span: span, + } + } + // everything else must be handled later, since they can be hidden by user-defined functions. + _ => { + let arguments = self.arguments(lexer, ctx)?; + ctx.unresolved.insert(ast::Dependency { + ident: name, + usage: name_span, + }); + ast::Expression::Call { + function: ast::Ident { + name, + span: name_span, + }, + arguments, + } + } + }; + + let span = self.peek_rule_span(lexer); + let expr = ctx.expressions.append(expr, span); + Ok(expr) + } + + fn ident_expr<'a>( + &mut self, + name: &'a str, + name_span: Span, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> ast::IdentExpr<'a> { + match ctx.local_table.lookup(name) { + Some(&local) => ast::IdentExpr::Local(local), + None => { + ctx.unresolved.insert(ast::Dependency { + ident: name, + usage: name_span, + }); + ast::IdentExpr::Unresolved(name) + } + } + } + + fn primary_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + self.push_rule_span(Rule::PrimaryExpr, lexer); + + let expr = match lexer.peek() { + (Token::Paren('('), _) => { + let _ = lexer.next(); + let expr = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren(')'))?; + self.pop_rule_span(lexer); + return Ok(expr); + } + (Token::Word("true"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Bool(true)) + } + (Token::Word("false"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Bool(false)) + } + (Token::Number(res), span) => { + let _ = lexer.next(); + let num = res.map_err(|err| Error::BadNumber(span, err))?; + ast::Expression::Literal(ast::Literal::Number(num)) + } + (Token::Word("RAY_FLAG_NONE"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(0))) + } + (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(4))) + } + (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(0))) + } + (Token::Word(word), span) => { + let start = lexer.start_byte_offset(); + let _ = lexer.next(); + + if let Some(ty) = self.constructor_type(lexer, word, span, ctx)? { + let ty_span = lexer.span_from(start); + let components = self.arguments(lexer, ctx)?; + ast::Expression::Construct { + ty, + ty_span, + components, + } + } else if let Token::Paren('(') = lexer.peek().0 { + self.pop_rule_span(lexer); + return self.function_call(lexer, word, span, ctx); + } else if word == "bitcast" { + self.pop_rule_span(lexer); + return self.function_call(lexer, word, span, ctx); + } else { + let ident = self.ident_expr(word, span, ctx); + ast::Expression::Ident(ident) + } + } + other => return Err(Error::Unexpected(other.1, ExpectedToken::PrimaryExpression)), + }; + + let span = self.pop_rule_span(lexer); + let expr = ctx.expressions.append(expr, span); + Ok(expr) + } + + fn postfix<'a>( + &mut self, + span_start: usize, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + expr: Handle<ast::Expression<'a>>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + let mut expr = expr; + + loop { + let expression = match lexer.peek().0 { + Token::Separator('.') => { + let _ = lexer.next(); + let field = lexer.next_ident()?; + + ast::Expression::Member { base: expr, field } + } + Token::Paren('[') => { + let _ = lexer.next(); + let index = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren(']'))?; + + ast::Expression::Index { base: expr, index } + } + _ => break, + }; + + let span = lexer.span_from(span_start); + expr = ctx.expressions.append(expression, span); + } + + Ok(expr) + } + + /// Parse a `unary_expression`. + fn unary_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + self.push_rule_span(Rule::UnaryExpr, lexer); + //TODO: refactor this to avoid backing up + let expr = match lexer.peek().0 { + Token::Operation('-') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr, + }; + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('!') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Unary { + op: crate::UnaryOperator::LogicalNot, + expr, + }; + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('~') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Unary { + op: crate::UnaryOperator::BitwiseNot, + expr, + }; + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('*') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Deref(expr); + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('&') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::AddrOf(expr); + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + _ => self.singular_expression(lexer, ctx)?, + }; + + self.pop_rule_span(lexer); + Ok(expr) + } + + /// Parse a `singular_expression`. + fn singular_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + let start = lexer.start_byte_offset(); + self.push_rule_span(Rule::SingularExpr, lexer); + let primary_expr = self.primary_expression(lexer, ctx)?; + let singular_expr = self.postfix(start, lexer, ctx, primary_expr)?; + self.pop_rule_span(lexer); + + Ok(singular_expr) + } + + fn equality_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + context: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + // equality_expression + context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal), + Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual), + _ => None, + }, + // relational_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Paren('<') => Some(crate::BinaryOperator::Less), + Token::Paren('>') => Some(crate::BinaryOperator::Greater), + Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual), + Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual), + _ => None, + }, + // shift_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::ShiftOperation('<') => { + Some(crate::BinaryOperator::ShiftLeft) + } + Token::ShiftOperation('>') => { + Some(crate::BinaryOperator::ShiftRight) + } + _ => None, + }, + // additive_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('+') => Some(crate::BinaryOperator::Add), + Token::Operation('-') => { + Some(crate::BinaryOperator::Subtract) + } + _ => None, + }, + // multiplicative_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('*') => { + Some(crate::BinaryOperator::Multiply) + } + Token::Operation('/') => { + Some(crate::BinaryOperator::Divide) + } + Token::Operation('%') => { + Some(crate::BinaryOperator::Modulo) + } + _ => None, + }, + |lexer, context| self.unary_expression(lexer, context), + ) + }, + ) + }, + ) + }, + ) + }, + ) + } + + fn general_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + self.general_expression_with_span(lexer, ctx) + .map(|(expr, _)| expr) + } + + fn general_expression_with_span<'a>( + &mut self, + lexer: &mut Lexer<'a>, + context: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<(Handle<ast::Expression<'a>>, Span), Error<'a>> { + self.push_rule_span(Rule::GeneralExpr, lexer); + // logical_or_expression + let handle = context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr), + _ => None, + }, + // logical_and_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd), + _ => None, + }, + // inclusive_or_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr), + _ => None, + }, + // exclusive_or_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('^') => { + Some(crate::BinaryOperator::ExclusiveOr) + } + _ => None, + }, + // and_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('&') => { + Some(crate::BinaryOperator::And) + } + _ => None, + }, + |lexer, context| { + self.equality_expression(lexer, context) + }, + ) + }, + ) + }, + ) + }, + ) + }, + )?; + Ok((handle, self.pop_rule_span(lexer))) + } + + fn variable_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<ast::GlobalVariable<'a>, Error<'a>> { + self.push_rule_span(Rule::VariableDecl, lexer); + let mut space = crate::AddressSpace::Handle; + + if lexer.skip(Token::Paren('<')) { + let (class_str, span) = lexer.next_ident_with_span()?; + space = match class_str { + "storage" => { + let access = if lexer.skip(Token::Separator(',')) { + lexer.next_storage_access()? + } else { + // defaulting to `read` + crate::StorageAccess::LOAD + }; + crate::AddressSpace::Storage { access } + } + _ => conv::map_address_space(class_str, span)?, + }; + lexer.expect(Token::Paren('>'))?; + } + let name = lexer.next_ident()?; + lexer.expect(Token::Separator(':'))?; + let ty = self.type_decl(lexer, ctx)?; + + let init = if lexer.skip(Token::Operation('=')) { + let handle = self.general_expression(lexer, ctx)?; + Some(handle) + } else { + None + }; + lexer.expect(Token::Separator(';'))?; + self.pop_rule_span(lexer); + + Ok(ast::GlobalVariable { + name, + space, + binding: None, + ty, + init, + }) + } + + fn struct_body<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Vec<ast::StructMember<'a>>, Error<'a>> { + let mut members = Vec::new(); + + lexer.expect(Token::Paren('{'))?; + let mut ready = true; + while !lexer.skip(Token::Paren('}')) { + if !ready { + return Err(Error::Unexpected( + lexer.next().1, + ExpectedToken::Token(Token::Separator(',')), + )); + } + let (mut size, mut align) = (ParsedAttribute::default(), ParsedAttribute::default()); + self.push_rule_span(Rule::Attribute, lexer); + let mut bind_parser = BindingParser::default(); + while lexer.skip(Token::Attribute) { + match lexer.next_ident_with_span()? { + ("size", name_span) => { + lexer.expect(Token::Paren('('))?; + let expr = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren(')'))?; + size.set(expr, name_span)?; + } + ("align", name_span) => { + lexer.expect(Token::Paren('('))?; + let expr = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren(')'))?; + align.set(expr, name_span)?; + } + (word, word_span) => bind_parser.parse(self, lexer, word, word_span, ctx)?, + } + } + + let bind_span = self.pop_rule_span(lexer); + let binding = bind_parser.finish(bind_span)?; + + let name = lexer.next_ident()?; + lexer.expect(Token::Separator(':'))?; + let ty = self.type_decl(lexer, ctx)?; + ready = lexer.skip(Token::Separator(',')); + + members.push(ast::StructMember { + name, + ty, + binding, + size: size.value, + align: align.value, + }); + } + + Ok(members) + } + + fn matrix_scalar_type<'a>( + &mut self, + lexer: &mut Lexer<'a>, + columns: crate::VectorSize, + rows: crate::VectorSize, + ) -> Result<ast::Type<'a>, Error<'a>> { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + match scalar.kind { + crate::ScalarKind::Float => Ok(ast::Type::Matrix { + columns, + rows, + width: scalar.width, + }), + _ => Err(Error::BadMatrixScalarKind(span, scalar)), + } + } + + fn type_decl_impl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + word: &'a str, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Option<ast::Type<'a>>, Error<'a>> { + if let Some(scalar) = conv::get_scalar_type(word) { + return Ok(Some(ast::Type::Scalar(scalar))); + } + + Ok(Some(match word { + "vec2" => { + let scalar = lexer.next_scalar_generic()?; + ast::Type::Vector { + size: crate::VectorSize::Bi, + scalar, + } + } + "vec2i" => ast::Type::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + }, + "vec2u" => ast::Type::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + }, + "vec2f" => ast::Type::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar::F32, + }, + "vec3" => { + let scalar = lexer.next_scalar_generic()?; + ast::Type::Vector { + size: crate::VectorSize::Tri, + scalar, + } + } + "vec3i" => ast::Type::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + }, + "vec3u" => ast::Type::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + }, + "vec3f" => ast::Type::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar::F32, + }, + "vec4" => { + let scalar = lexer.next_scalar_generic()?; + ast::Type::Vector { + size: crate::VectorSize::Quad, + scalar, + } + } + "vec4i" => ast::Type::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + }, + "vec4u" => ast::Type::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + }, + "vec4f" => ast::Type::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::F32, + }, + "mat2x2" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Bi)? + } + "mat2x2f" => ast::Type::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Bi, + width: 4, + }, + "mat2x3" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Tri)? + } + "mat2x3f" => ast::Type::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Tri, + width: 4, + }, + "mat2x4" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Quad)? + } + "mat2x4f" => ast::Type::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Quad, + width: 4, + }, + "mat3x2" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Bi)? + } + "mat3x2f" => ast::Type::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Bi, + width: 4, + }, + "mat3x3" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Tri)? + } + "mat3x3f" => ast::Type::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Tri, + width: 4, + }, + "mat3x4" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Quad)? + } + "mat3x4f" => ast::Type::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Quad, + width: 4, + }, + "mat4x2" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Bi)? + } + "mat4x2f" => ast::Type::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Bi, + width: 4, + }, + "mat4x3" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Tri)? + } + "mat4x3f" => ast::Type::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width: 4, + }, + "mat4x4" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Quad)? + } + "mat4x4f" => ast::Type::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Quad, + width: 4, + }, + "atomic" => { + let scalar = lexer.next_scalar_generic()?; + ast::Type::Atomic(scalar) + } + "ptr" => { + lexer.expect_generic_paren('<')?; + let (ident, span) = lexer.next_ident_with_span()?; + let mut space = conv::map_address_space(ident, span)?; + lexer.expect(Token::Separator(','))?; + let base = self.type_decl(lexer, ctx)?; + if let crate::AddressSpace::Storage { ref mut access } = space { + *access = if lexer.skip(Token::Separator(',')) { + lexer.next_storage_access()? + } else { + crate::StorageAccess::LOAD + }; + } + lexer.expect_generic_paren('>')?; + ast::Type::Pointer { base, space } + } + "array" => { + lexer.expect_generic_paren('<')?; + let base = self.type_decl(lexer, ctx)?; + let size = if lexer.skip(Token::Separator(',')) { + let size = self.unary_expression(lexer, ctx)?; + ast::ArraySize::Constant(size) + } else { + ast::ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + ast::Type::Array { base, size } + } + "binding_array" => { + lexer.expect_generic_paren('<')?; + let base = self.type_decl(lexer, ctx)?; + let size = if lexer.skip(Token::Separator(',')) { + let size = self.unary_expression(lexer, ctx)?; + ast::ArraySize::Constant(size) + } else { + ast::ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + ast::Type::BindingArray { base, size } + } + "sampler" => ast::Type::Sampler { comparison: false }, + "sampler_comparison" => ast::Type::Sampler { comparison: true }, + "texture_1d" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D1, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_1d_array" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D1, + arrayed: true, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_2d" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_2d_array" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_3d" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D3, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_cube" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::Cube, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_cube_array" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_multisampled_2d" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: true, + }, + } + } + "texture_multisampled_2d_array" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: true, + }, + } + } + "texture_depth_2d" => ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_2d_array" => ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_cube" => ast::Type::Image { + dim: crate::ImageDimension::Cube, + arrayed: false, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_cube_array" => ast::Type::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_multisampled_2d" => ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth { multi: true }, + }, + "texture_storage_1d" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D1, + arrayed: false, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_1d_array" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D1, + arrayed: true, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_2d" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_2d_array" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_3d" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D3, + arrayed: false, + class: crate::ImageClass::Storage { format, access }, + } + } + "acceleration_structure" => ast::Type::AccelerationStructure, + "ray_query" => ast::Type::RayQuery, + "RayDesc" => ast::Type::RayDesc, + "RayIntersection" => ast::Type::RayIntersection, + _ => return Ok(None), + })) + } + + const fn check_texture_sample_type(scalar: Scalar, span: Span) -> Result<(), Error<'static>> { + use crate::ScalarKind::*; + // Validate according to https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type + match scalar { + Scalar { + kind: Float | Sint | Uint, + width: 4, + } => Ok(()), + _ => Err(Error::BadTextureSampleType { span, scalar }), + } + } + + /// Parse type declaration of a given name. + fn type_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Type<'a>>, Error<'a>> { + self.push_rule_span(Rule::TypeDecl, lexer); + + let (name, span) = lexer.next_ident_with_span()?; + + let ty = match self.type_decl_impl(lexer, name, ctx)? { + Some(ty) => ty, + None => { + ctx.unresolved.insert(ast::Dependency { + ident: name, + usage: span, + }); + ast::Type::User(ast::Ident { name, span }) + } + }; + + self.pop_rule_span(lexer); + + let handle = ctx.types.append(ty, Span::UNDEFINED); + Ok(handle) + } + + fn assignment_op_and_rhs<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + target: Handle<ast::Expression<'a>>, + span_start: usize, + ) -> Result<(), Error<'a>> { + use crate::BinaryOperator as Bo; + + let op = lexer.next(); + let (op, value) = match op { + (Token::Operation('='), _) => { + let value = self.general_expression(lexer, ctx)?; + (None, value) + } + (Token::AssignmentOperation(c), _) => { + let op = match c { + '<' => Bo::ShiftLeft, + '>' => Bo::ShiftRight, + '+' => Bo::Add, + '-' => Bo::Subtract, + '*' => Bo::Multiply, + '/' => Bo::Divide, + '%' => Bo::Modulo, + '&' => Bo::And, + '|' => Bo::InclusiveOr, + '^' => Bo::ExclusiveOr, + // Note: `consume_token` shouldn't produce any other assignment ops + _ => unreachable!(), + }; + + let value = self.general_expression(lexer, ctx)?; + (Some(op), value) + } + token @ (Token::IncrementOperation | Token::DecrementOperation, _) => { + let op = match token.0 { + Token::IncrementOperation => ast::StatementKind::Increment, + Token::DecrementOperation => ast::StatementKind::Decrement, + _ => unreachable!(), + }; + + let span = lexer.span_from(span_start); + block.stmts.push(ast::Statement { + kind: op(target), + span, + }); + return Ok(()); + } + _ => return Err(Error::Unexpected(op.1, ExpectedToken::Assignment)), + }; + + let span = lexer.span_from(span_start); + block.stmts.push(ast::Statement { + kind: ast::StatementKind::Assign { target, op, value }, + span, + }); + Ok(()) + } + + /// Parse an assignment statement (will also parse increment and decrement statements) + fn assignment_statement<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + ) -> Result<(), Error<'a>> { + let span_start = lexer.start_byte_offset(); + let target = self.general_expression(lexer, ctx)?; + self.assignment_op_and_rhs(lexer, ctx, block, target, span_start) + } + + /// Parse a function call statement. + /// Expects `ident` to be consumed (not in the lexer). + fn function_statement<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ident: &'a str, + ident_span: Span, + span_start: usize, + context: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + ) -> Result<(), Error<'a>> { + self.push_rule_span(Rule::SingularExpr, lexer); + + context.unresolved.insert(ast::Dependency { + ident, + usage: ident_span, + }); + let arguments = self.arguments(lexer, context)?; + let span = lexer.span_from(span_start); + + block.stmts.push(ast::Statement { + kind: ast::StatementKind::Call { + function: ast::Ident { + name: ident, + span: ident_span, + }, + arguments, + }, + span, + }); + + self.pop_rule_span(lexer); + + Ok(()) + } + + fn function_call_or_assignment_statement<'a>( + &mut self, + lexer: &mut Lexer<'a>, + context: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + ) -> Result<(), Error<'a>> { + let span_start = lexer.start_byte_offset(); + match lexer.peek() { + (Token::Word(name), span) => { + // A little hack for 2 token lookahead. + let cloned = lexer.clone(); + let _ = lexer.next(); + match lexer.peek() { + (Token::Paren('('), _) => { + self.function_statement(lexer, name, span, span_start, context, block) + } + _ => { + *lexer = cloned; + self.assignment_statement(lexer, context, block) + } + } + } + _ => self.assignment_statement(lexer, context, block), + } + } + + fn statement<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + ) -> Result<(), Error<'a>> { + self.push_rule_span(Rule::Statement, lexer); + match lexer.peek() { + (Token::Separator(';'), _) => { + let _ = lexer.next(); + self.pop_rule_span(lexer); + return Ok(()); + } + (Token::Paren('{'), _) => { + let (inner, span) = self.block(lexer, ctx)?; + block.stmts.push(ast::Statement { + kind: ast::StatementKind::Block(inner), + span, + }); + self.pop_rule_span(lexer); + return Ok(()); + } + (Token::Word(word), _) => { + let kind = match word { + "_" => { + let _ = lexer.next(); + lexer.expect(Token::Operation('='))?; + let expr = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + + ast::StatementKind::Ignore(expr) + } + "let" => { + let _ = lexer.next(); + let name = lexer.next_ident()?; + + let given_ty = if lexer.skip(Token::Separator(':')) { + let ty = self.type_decl(lexer, ctx)?; + Some(ty) + } else { + None + }; + lexer.expect(Token::Operation('='))?; + let expr_id = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + + let handle = ctx.declare_local(name)?; + ast::StatementKind::LocalDecl(ast::LocalDecl::Let(ast::Let { + name, + ty: given_ty, + init: expr_id, + handle, + })) + } + "var" => { + let _ = lexer.next(); + + let name = lexer.next_ident()?; + let ty = if lexer.skip(Token::Separator(':')) { + let ty = self.type_decl(lexer, ctx)?; + Some(ty) + } else { + None + }; + + let init = if lexer.skip(Token::Operation('=')) { + let init = self.general_expression(lexer, ctx)?; + Some(init) + } else { + None + }; + + lexer.expect(Token::Separator(';'))?; + + let handle = ctx.declare_local(name)?; + ast::StatementKind::LocalDecl(ast::LocalDecl::Var(ast::LocalVariable { + name, + ty, + init, + handle, + })) + } + "return" => { + let _ = lexer.next(); + let value = if lexer.peek().0 != Token::Separator(';') { + let handle = self.general_expression(lexer, ctx)?; + Some(handle) + } else { + None + }; + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Return { value } + } + "if" => { + let _ = lexer.next(); + let condition = self.general_expression(lexer, ctx)?; + + let accept = self.block(lexer, ctx)?.0; + + let mut elsif_stack = Vec::new(); + let mut elseif_span_start = lexer.start_byte_offset(); + let mut reject = loop { + if !lexer.skip(Token::Word("else")) { + break ast::Block::default(); + } + + if !lexer.skip(Token::Word("if")) { + // ... else { ... } + break self.block(lexer, ctx)?.0; + } + + // ... else if (...) { ... } + let other_condition = self.general_expression(lexer, ctx)?; + let other_block = self.block(lexer, ctx)?; + elsif_stack.push((elseif_span_start, other_condition, other_block)); + elseif_span_start = lexer.start_byte_offset(); + }; + + // reverse-fold the else-if blocks + //Note: we may consider uplifting this to the IR + for (other_span_start, other_cond, other_block) in + elsif_stack.into_iter().rev() + { + let sub_stmt = ast::StatementKind::If { + condition: other_cond, + accept: other_block.0, + reject, + }; + reject = ast::Block::default(); + let span = lexer.span_from(other_span_start); + reject.stmts.push(ast::Statement { + kind: sub_stmt, + span, + }) + } + + ast::StatementKind::If { + condition, + accept, + reject, + } + } + "switch" => { + let _ = lexer.next(); + let selector = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren('{'))?; + let mut cases = Vec::new(); + + loop { + // cases + default + match lexer.next() { + (Token::Word("case"), _) => { + // parse a list of values + let value = loop { + let value = self.switch_value(lexer, ctx)?; + if lexer.skip(Token::Separator(',')) { + if lexer.skip(Token::Separator(':')) { + break value; + } + } else { + lexer.skip(Token::Separator(':')); + break value; + } + cases.push(ast::SwitchCase { + value, + body: ast::Block::default(), + fall_through: true, + }); + }; + + let body = self.block(lexer, ctx)?.0; + + cases.push(ast::SwitchCase { + value, + body, + fall_through: false, + }); + } + (Token::Word("default"), _) => { + lexer.skip(Token::Separator(':')); + let body = self.block(lexer, ctx)?.0; + cases.push(ast::SwitchCase { + value: ast::SwitchValue::Default, + body, + fall_through: false, + }); + } + (Token::Paren('}'), _) => break, + (_, span) => { + return Err(Error::Unexpected(span, ExpectedToken::SwitchItem)) + } + } + } + + ast::StatementKind::Switch { selector, cases } + } + "loop" => self.r#loop(lexer, ctx)?, + "while" => { + let _ = lexer.next(); + let mut body = ast::Block::default(); + + let (condition, span) = lexer.capture_span(|lexer| { + let condition = self.general_expression(lexer, ctx)?; + Ok(condition) + })?; + let mut reject = ast::Block::default(); + reject.stmts.push(ast::Statement { + kind: ast::StatementKind::Break, + span, + }); + + body.stmts.push(ast::Statement { + kind: ast::StatementKind::If { + condition, + accept: ast::Block::default(), + reject, + }, + span, + }); + + let (block, span) = self.block(lexer, ctx)?; + body.stmts.push(ast::Statement { + kind: ast::StatementKind::Block(block), + span, + }); + + ast::StatementKind::Loop { + body, + continuing: ast::Block::default(), + break_if: None, + } + } + "for" => { + let _ = lexer.next(); + lexer.expect(Token::Paren('('))?; + + ctx.local_table.push_scope(); + + if !lexer.skip(Token::Separator(';')) { + let num_statements = block.stmts.len(); + let (_, span) = { + let ctx = &mut *ctx; + let block = &mut *block; + lexer.capture_span(|lexer| self.statement(lexer, ctx, block))? + }; + + if block.stmts.len() != num_statements { + match block.stmts.last().unwrap().kind { + ast::StatementKind::Call { .. } + | ast::StatementKind::Assign { .. } + | ast::StatementKind::LocalDecl(_) => {} + _ => return Err(Error::InvalidForInitializer(span)), + } + } + }; + + let mut body = ast::Block::default(); + if !lexer.skip(Token::Separator(';')) { + let (condition, span) = lexer.capture_span(|lexer| { + let condition = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + Ok(condition) + })?; + let mut reject = ast::Block::default(); + reject.stmts.push(ast::Statement { + kind: ast::StatementKind::Break, + span, + }); + body.stmts.push(ast::Statement { + kind: ast::StatementKind::If { + condition, + accept: ast::Block::default(), + reject, + }, + span, + }); + }; + + let mut continuing = ast::Block::default(); + if !lexer.skip(Token::Paren(')')) { + self.function_call_or_assignment_statement( + lexer, + ctx, + &mut continuing, + )?; + lexer.expect(Token::Paren(')'))?; + } + + let (block, span) = self.block(lexer, ctx)?; + body.stmts.push(ast::Statement { + kind: ast::StatementKind::Block(block), + span, + }); + + ctx.local_table.pop_scope(); + + ast::StatementKind::Loop { + body, + continuing, + break_if: None, + } + } + "break" => { + let (_, span) = lexer.next(); + // Check if the next token is an `if`, this indicates + // that the user tried to type out a `break if` which + // is illegal in this position. + let (peeked_token, peeked_span) = lexer.peek(); + if let Token::Word("if") = peeked_token { + let span = span.until(&peeked_span); + return Err(Error::InvalidBreakIf(span)); + } + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Break + } + "continue" => { + let _ = lexer.next(); + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Continue + } + "discard" => { + let _ = lexer.next(); + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Kill + } + // assignment or a function call + _ => { + self.function_call_or_assignment_statement(lexer, ctx, block)?; + lexer.expect(Token::Separator(';'))?; + self.pop_rule_span(lexer); + return Ok(()); + } + }; + + let span = self.pop_rule_span(lexer); + block.stmts.push(ast::Statement { kind, span }); + } + _ => { + self.assignment_statement(lexer, ctx, block)?; + lexer.expect(Token::Separator(';'))?; + self.pop_rule_span(lexer); + } + } + Ok(()) + } + + fn r#loop<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<ast::StatementKind<'a>, Error<'a>> { + let _ = lexer.next(); + let mut body = ast::Block::default(); + let mut continuing = ast::Block::default(); + let mut break_if = None; + + lexer.expect(Token::Paren('{'))?; + + ctx.local_table.push_scope(); + + loop { + if lexer.skip(Token::Word("continuing")) { + // Branch for the `continuing` block, this must be + // the last thing in the loop body + + // Expect a opening brace to start the continuing block + lexer.expect(Token::Paren('{'))?; + loop { + if lexer.skip(Token::Word("break")) { + // Branch for the `break if` statement, this statement + // has the form `break if <expr>;` and must be the last + // statement in a continuing block + + // The break must be followed by an `if` to form + // the break if + lexer.expect(Token::Word("if"))?; + + let condition = self.general_expression(lexer, ctx)?; + // Set the condition of the break if to the newly parsed + // expression + break_if = Some(condition); + + // Expect a semicolon to close the statement + lexer.expect(Token::Separator(';'))?; + // Expect a closing brace to close the continuing block, + // since the break if must be the last statement + lexer.expect(Token::Paren('}'))?; + // Stop parsing the continuing block + break; + } else if lexer.skip(Token::Paren('}')) { + // If we encounter a closing brace it means we have reached + // the end of the continuing block and should stop processing + break; + } else { + // Otherwise try to parse a statement + self.statement(lexer, ctx, &mut continuing)?; + } + } + // Since the continuing block must be the last part of the loop body, + // we expect to see a closing brace to end the loop body + lexer.expect(Token::Paren('}'))?; + break; + } + if lexer.skip(Token::Paren('}')) { + // If we encounter a closing brace it means we have reached + // the end of the loop body and should stop processing + break; + } + // Otherwise try to parse a statement + self.statement(lexer, ctx, &mut body)?; + } + + ctx.local_table.pop_scope(); + + Ok(ast::StatementKind::Loop { + body, + continuing, + break_if, + }) + } + + /// compound_statement + fn block<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<(ast::Block<'a>, Span), Error<'a>> { + self.push_rule_span(Rule::Block, lexer); + + ctx.local_table.push_scope(); + + lexer.expect(Token::Paren('{'))?; + let mut block = ast::Block::default(); + while !lexer.skip(Token::Paren('}')) { + self.statement(lexer, ctx, &mut block)?; + } + + ctx.local_table.pop_scope(); + + let span = self.pop_rule_span(lexer); + Ok((block, span)) + } + + fn varying_binding<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Option<ast::Binding<'a>>, Error<'a>> { + let mut bind_parser = BindingParser::default(); + self.push_rule_span(Rule::Attribute, lexer); + + while lexer.skip(Token::Attribute) { + let (word, span) = lexer.next_ident_with_span()?; + bind_parser.parse(self, lexer, word, span, ctx)?; + } + + let span = self.pop_rule_span(lexer); + bind_parser.finish(span) + } + + fn function_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + out: &mut ast::TranslationUnit<'a>, + dependencies: &mut FastIndexSet<ast::Dependency<'a>>, + ) -> Result<ast::Function<'a>, Error<'a>> { + self.push_rule_span(Rule::FunctionDecl, lexer); + // read function name + let fun_name = lexer.next_ident()?; + + let mut locals = Arena::new(); + + let mut ctx = ExpressionContext { + expressions: &mut out.expressions, + local_table: &mut SymbolTable::default(), + locals: &mut locals, + types: &mut out.types, + unresolved: dependencies, + }; + + // start a scope that contains arguments as well as the function body + ctx.local_table.push_scope(); + + // read parameter list + let mut arguments = Vec::new(); + lexer.expect(Token::Paren('('))?; + let mut ready = true; + while !lexer.skip(Token::Paren(')')) { + if !ready { + return Err(Error::Unexpected( + lexer.next().1, + ExpectedToken::Token(Token::Separator(',')), + )); + } + let binding = self.varying_binding(lexer, &mut ctx)?; + + let param_name = lexer.next_ident()?; + + lexer.expect(Token::Separator(':'))?; + let param_type = self.type_decl(lexer, &mut ctx)?; + + let handle = ctx.declare_local(param_name)?; + arguments.push(ast::FunctionArgument { + name: param_name, + ty: param_type, + binding, + handle, + }); + ready = lexer.skip(Token::Separator(',')); + } + // read return type + let result = if lexer.skip(Token::Arrow) && !lexer.skip(Token::Word("void")) { + let binding = self.varying_binding(lexer, &mut ctx)?; + let ty = self.type_decl(lexer, &mut ctx)?; + Some(ast::FunctionResult { ty, binding }) + } else { + None + }; + + // do not use `self.block` here, since we must not push a new scope + lexer.expect(Token::Paren('{'))?; + let mut body = ast::Block::default(); + while !lexer.skip(Token::Paren('}')) { + self.statement(lexer, &mut ctx, &mut body)?; + } + + ctx.local_table.pop_scope(); + + let fun = ast::Function { + entry_point: None, + name: fun_name, + arguments, + result, + body, + locals, + }; + + // done + self.pop_rule_span(lexer); + + Ok(fun) + } + + fn global_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + out: &mut ast::TranslationUnit<'a>, + ) -> Result<(), Error<'a>> { + // read attributes + let mut binding = None; + let mut stage = ParsedAttribute::default(); + let mut compute_span = Span::new(0, 0); + let mut workgroup_size = ParsedAttribute::default(); + let mut early_depth_test = ParsedAttribute::default(); + let (mut bind_index, mut bind_group) = + (ParsedAttribute::default(), ParsedAttribute::default()); + + let mut dependencies = FastIndexSet::default(); + let mut ctx = ExpressionContext { + expressions: &mut out.expressions, + local_table: &mut SymbolTable::default(), + locals: &mut Arena::new(), + types: &mut out.types, + unresolved: &mut dependencies, + }; + + self.push_rule_span(Rule::Attribute, lexer); + while lexer.skip(Token::Attribute) { + match lexer.next_ident_with_span()? { + ("binding", name_span) => { + lexer.expect(Token::Paren('('))?; + bind_index.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + ("group", name_span) => { + lexer.expect(Token::Paren('('))?; + bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + ("vertex", name_span) => { + stage.set(crate::ShaderStage::Vertex, name_span)?; + } + ("fragment", name_span) => { + stage.set(crate::ShaderStage::Fragment, name_span)?; + } + ("compute", name_span) => { + stage.set(crate::ShaderStage::Compute, name_span)?; + compute_span = name_span; + } + ("workgroup_size", name_span) => { + lexer.expect(Token::Paren('('))?; + let mut new_workgroup_size = [None; 3]; + for (i, size) in new_workgroup_size.iter_mut().enumerate() { + *size = Some(self.general_expression(lexer, &mut ctx)?); + match lexer.next() { + (Token::Paren(')'), _) => break, + (Token::Separator(','), _) if i != 2 => (), + other => { + return Err(Error::Unexpected( + other.1, + ExpectedToken::WorkgroupSizeSeparator, + )) + } + } + } + workgroup_size.set(new_workgroup_size, name_span)?; + } + ("early_depth_test", name_span) => { + let conservative = if lexer.skip(Token::Paren('(')) { + let (ident, ident_span) = lexer.next_ident_with_span()?; + let value = conv::map_conservative_depth(ident, ident_span)?; + lexer.expect(Token::Paren(')'))?; + Some(value) + } else { + None + }; + early_depth_test.set(crate::EarlyDepthTest { conservative }, name_span)?; + } + (_, word_span) => return Err(Error::UnknownAttribute(word_span)), + } + } + + let attrib_span = self.pop_rule_span(lexer); + match (bind_group.value, bind_index.value) { + (Some(group), Some(index)) => { + binding = Some(ast::ResourceBinding { + group, + binding: index, + }); + } + (Some(_), None) => return Err(Error::MissingAttribute("binding", attrib_span)), + (None, Some(_)) => return Err(Error::MissingAttribute("group", attrib_span)), + (None, None) => {} + } + + // read item + let start = lexer.start_byte_offset(); + let kind = match lexer.next() { + (Token::Separator(';'), _) => None, + (Token::Word("struct"), _) => { + let name = lexer.next_ident()?; + + let members = self.struct_body(lexer, &mut ctx)?; + Some(ast::GlobalDeclKind::Struct(ast::Struct { name, members })) + } + (Token::Word("alias"), _) => { + let name = lexer.next_ident()?; + + lexer.expect(Token::Operation('='))?; + let ty = self.type_decl(lexer, &mut ctx)?; + lexer.expect(Token::Separator(';'))?; + Some(ast::GlobalDeclKind::Type(ast::TypeAlias { name, ty })) + } + (Token::Word("const"), _) => { + let name = lexer.next_ident()?; + + let ty = if lexer.skip(Token::Separator(':')) { + let ty = self.type_decl(lexer, &mut ctx)?; + Some(ty) + } else { + None + }; + + lexer.expect(Token::Operation('='))?; + let init = self.general_expression(lexer, &mut ctx)?; + lexer.expect(Token::Separator(';'))?; + + Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) + } + (Token::Word("var"), _) => { + let mut var = self.variable_decl(lexer, &mut ctx)?; + var.binding = binding.take(); + Some(ast::GlobalDeclKind::Var(var)) + } + (Token::Word("fn"), _) => { + let function = self.function_decl(lexer, out, &mut dependencies)?; + Some(ast::GlobalDeclKind::Fn(ast::Function { + entry_point: if let Some(stage) = stage.value { + if stage == ShaderStage::Compute && workgroup_size.value.is_none() { + return Err(Error::MissingWorkgroupSize(compute_span)); + } + Some(ast::EntryPoint { + stage, + early_depth_test: early_depth_test.value, + workgroup_size: workgroup_size.value, + }) + } else { + None + }, + ..function + })) + } + (Token::End, _) => return Ok(()), + other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)), + }; + + if let Some(kind) = kind { + out.decls.append( + ast::GlobalDecl { kind, dependencies }, + lexer.span_from(start), + ); + } + + if !self.rules.is_empty() { + log::error!("Reached the end of global decl, but rule stack is not empty"); + log::error!("Rules: {:?}", self.rules); + return Err(Error::Internal("rule stack is not empty")); + }; + + match binding { + None => Ok(()), + Some(_) => Err(Error::Internal("we had the attribute but no var?")), + } + } + + pub fn parse<'a>(&mut self, source: &'a str) -> Result<ast::TranslationUnit<'a>, Error<'a>> { + self.reset(); + + let mut lexer = Lexer::new(source); + let mut tu = ast::TranslationUnit::default(); + loop { + match self.global_decl(&mut lexer, &mut tu) { + Err(error) => return Err(error), + Ok(()) => { + if lexer.peek().0 == Token::End { + break; + } + } + } + } + + Ok(tu) + } +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/number.rs b/third_party/rust/naga/src/front/wgsl/parse/number.rs new file mode 100644 index 0000000000..7b09ac59bb --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/number.rs @@ -0,0 +1,420 @@ +use crate::front::wgsl::error::NumberError; +use crate::front::wgsl::parse::lexer::Token; + +/// When using this type assume no Abstract Int/Float for now +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Number { + /// Abstract Int (-2^63 ≤ i < 2^63) + AbstractInt(i64), + /// Abstract Float (IEEE-754 binary64) + AbstractFloat(f64), + /// Concrete i32 + I32(i32), + /// Concrete u32 + U32(u32), + /// Concrete f32 + F32(f32), + /// Concrete f64 + F64(f64), +} + +pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) { + let (result, rest) = parse(input); + (Token::Number(result), rest) +} + +enum Kind { + Int(IntKind), + Float(FloatKind), +} + +enum IntKind { + I32, + U32, +} + +#[derive(Debug)] +enum FloatKind { + F16, + F32, + F64, +} + +// The following regexes (from the WGSL spec) will be matched: + +// int_literal: +// | / 0 [iu]? / +// | / [1-9][0-9]* [iu]? / +// | / 0[xX][0-9a-fA-F]+ [iu]? / + +// decimal_float_literal: +// | / 0 [fh] / +// | / [1-9][0-9]* [fh] / +// | / [0-9]* \.[0-9]+ ([eE][+-]?[0-9]+)? [fh]? / +// | / [0-9]+ \.[0-9]* ([eE][+-]?[0-9]+)? [fh]? / +// | / [0-9]+ [eE][+-]?[0-9]+ [fh]? / + +// hex_float_literal: +// | / 0[xX][0-9a-fA-F]* \.[0-9a-fA-F]+ ([pP][+-]?[0-9]+ [fh]?)? / +// | / 0[xX][0-9a-fA-F]+ \.[0-9a-fA-F]* ([pP][+-]?[0-9]+ [fh]?)? / +// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? / + +// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing +// (?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?)) + +// Leading signs are handled as unary operators. + +fn parse(input: &str) -> (Result<Number, NumberError>, &str) { + /// returns `true` and consumes `X` bytes from the given byte buffer + /// if the given `X` nr of patterns are found at the start of the buffer + macro_rules! consume { + ($bytes:ident, $($pattern:pat),*) => { + match $bytes { + &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true }, + _ => false, + } + }; + } + + /// consumes one byte from the given byte buffer + /// if one of the given patterns are found at the start of the buffer + /// returning the corresponding expr for the matched pattern + macro_rules! consume_map { + ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => { + match $bytes { + $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )* + _ => None, + } + }; + } + + /// consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer + /// returning the number of consumed bytes + macro_rules! consume_dec_digits { + ($bytes:ident) => {{ + let start_len = $bytes.len(); + while let &[b'0'..=b'9', ref rest @ ..] = $bytes { + $bytes = rest; + } + start_len - $bytes.len() + }}; + } + + /// consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given byte buffer + /// returning the number of consumed bytes + macro_rules! consume_hex_digits { + ($bytes:ident) => {{ + let start_len = $bytes.len(); + while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes { + $bytes = rest; + } + start_len - $bytes.len() + }}; + } + + macro_rules! consume_float_suffix { + ($bytes:ident) => { + consume_map!($bytes, [ + b'h' => FloatKind::F16, + b'f' => FloatKind::F32, + b'l', b'f' => FloatKind::F64, + ]) + }; + } + + /// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str` + macro_rules! rest_to_str { + ($bytes:ident) => { + &input[input.len() - $bytes.len()..] + }; + } + + struct ExtractSubStr<'a>(&'a str); + + impl<'a> ExtractSubStr<'a> { + /// given an `input` and a `start` (tail of the `input`) + /// creates a new [`ExtractSubStr`](`Self`) + fn start(input: &'a str, start: &'a [u8]) -> Self { + let start = input.len() - start.len(); + Self(&input[start..]) + } + /// given an `end` (tail of the initial `input`) + /// returns a substring of `input` + fn end(&self, end: &'a [u8]) -> &'a str { + let end = self.0.len() - end.len(); + &self.0[..end] + } + } + + let mut bytes = input.as_bytes(); + + let general_extract = ExtractSubStr::start(input, bytes); + + if consume!(bytes, b'0', b'x' | b'X') { + let digits_extract = ExtractSubStr::start(input, bytes); + + let consumed = consume_hex_digits!(bytes); + + if consume!(bytes, b'.') { + let consumed_after_period = consume_hex_digits!(bytes); + + if consumed + consumed_after_period == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let significand = general_extract.end(bytes); + + if consume!(bytes, b'p' | b'P') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let number = general_extract.end(bytes); + + let kind = consume_float_suffix!(bytes); + + (parse_hex_float(number, kind), rest_to_str!(bytes)) + } else { + ( + parse_hex_float_missing_exponent(significand, None), + rest_to_str!(bytes), + ) + } + } else { + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let significand = general_extract.end(bytes); + let digits = digits_extract.end(bytes); + + let exp_extract = ExtractSubStr::start(input, bytes); + + if consume!(bytes, b'p' | b'P') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let exponent = exp_extract.end(bytes); + + let kind = consume_float_suffix!(bytes); + + ( + parse_hex_float_missing_period(significand, exponent, kind), + rest_to_str!(bytes), + ) + } else { + let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]); + + (parse_hex_int(digits, kind), rest_to_str!(bytes)) + } + } + } else { + let is_first_zero = bytes.first() == Some(&b'0'); + + let consumed = consume_dec_digits!(bytes); + + if consume!(bytes, b'.') { + let consumed_after_period = consume_dec_digits!(bytes); + + if consumed + consumed_after_period == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + if consume!(bytes, b'e' | b'E') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + } + + let number = general_extract.end(bytes); + + let kind = consume_float_suffix!(bytes); + + (parse_dec_float(number, kind), rest_to_str!(bytes)) + } else { + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + if consume!(bytes, b'e' | b'E') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let number = general_extract.end(bytes); + + let kind = consume_float_suffix!(bytes); + + (parse_dec_float(number, kind), rest_to_str!(bytes)) + } else { + // make sure the multi-digit numbers don't start with zero + if consumed > 1 && is_first_zero { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let digits = general_extract.end(bytes); + + let kind = consume_map!(bytes, [ + b'i' => Kind::Int(IntKind::I32), + b'u' => Kind::Int(IntKind::U32), + b'h' => Kind::Float(FloatKind::F16), + b'f' => Kind::Float(FloatKind::F32), + b'l', b'f' => Kind::Float(FloatKind::F64), + ]); + + (parse_dec(digits, kind), rest_to_str!(bytes)) + } + } + } +} + +fn parse_hex_float_missing_exponent( + // format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) + significand: &str, + kind: Option<FloatKind>, +) -> Result<Number, NumberError> { + let hexf_input = format!("{}{}", significand, "p0"); + parse_hex_float(&hexf_input, kind) +} + +fn parse_hex_float_missing_period( + // format: 0[xX] [0-9a-fA-F]+ + significand: &str, + // format: [pP][+-]?[0-9]+ + exponent: &str, + kind: Option<FloatKind>, +) -> Result<Number, NumberError> { + let hexf_input = format!("{significand}.{exponent}"); + parse_hex_float(&hexf_input, kind) +} + +fn parse_hex_int( + // format: [0-9a-fA-F]+ + digits: &str, + kind: Option<IntKind>, +) -> Result<Number, NumberError> { + parse_int(digits, kind, 16) +} + +fn parse_dec( + // format: ( [0-9] | [1-9][0-9]+ ) + digits: &str, + kind: Option<Kind>, +) -> Result<Number, NumberError> { + match kind { + None => parse_int(digits, None, 10), + Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10), + Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)), + } +} + +// Float parsing notes + +// The following chapters of IEEE 754-2019 are relevant: +// +// 7.4 Overflow (largest finite number is exceeded by what would have been +// the rounded floating-point result were the exponent range unbounded) +// +// 7.5 Underflow (tiny non-zero result is detected; +// for decimal formats tininess is detected before rounding when a non-zero result +// computed as though both the exponent range and the precision were unbounded +// would lie strictly between 2^−126) +// +// 7.6 Inexact (rounded result differs from what would have been computed +// were both exponent range and precision unbounded) + +// The WGSL spec requires us to error: +// on overflow for decimal floating point literals +// on overflow and inexact for hexadecimal floating point literals +// (underflow is not mentioned) + +// hexf_parse errors on overflow, underflow, inexact +// rust std lib float from str handles overflow, underflow, inexact transparently (rounds and will not error) + +// Therefore we only check for overflow manually for decimal floating point literals + +// input format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+ +fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> { + match kind { + None => match hexf_parse::parse_hexf64(input, false) { + Ok(num) => Ok(Number::AbstractFloat(num)), + // can only be ParseHexfErrorKind::Inexact but we can't check since it's private + _ => Err(NumberError::NotRepresentable), + }, + Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), + Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) { + Ok(num) => Ok(Number::F32(num)), + // can only be ParseHexfErrorKind::Inexact but we can't check since it's private + _ => Err(NumberError::NotRepresentable), + }, + Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) { + Ok(num) => Ok(Number::F64(num)), + // can only be ParseHexfErrorKind::Inexact but we can't check since it's private + _ => Err(NumberError::NotRepresentable), + }, + } +} + +// input format: ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)? +// | [0-9]+ [eE][+-]?[0-9]+ +fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> { + match kind { + None => { + let num = input.parse::<f64>().unwrap(); // will never fail + num.is_finite() + .then_some(Number::AbstractFloat(num)) + .ok_or(NumberError::NotRepresentable) + } + Some(FloatKind::F32) => { + let num = input.parse::<f32>().unwrap(); // will never fail + num.is_finite() + .then_some(Number::F32(num)) + .ok_or(NumberError::NotRepresentable) + } + Some(FloatKind::F64) => { + let num = input.parse::<f64>().unwrap(); // will never fail + num.is_finite() + .then_some(Number::F64(num)) + .ok_or(NumberError::NotRepresentable) + } + Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), + } +} + +fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> { + fn map_err(e: core::num::ParseIntError) -> NumberError { + match *e.kind() { + core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => { + NumberError::NotRepresentable + } + _ => unreachable!(), + } + } + match kind { + None => match i64::from_str_radix(input, radix) { + Ok(num) => Ok(Number::AbstractInt(num)), + Err(e) => Err(map_err(e)), + }, + Some(IntKind::I32) => match i32::from_str_radix(input, radix) { + Ok(num) => Ok(Number::I32(num)), + Err(e) => Err(map_err(e)), + }, + Some(IntKind::U32) => match u32::from_str_radix(input, radix) { + Ok(num) => Ok(Number::U32(num)), + Err(e) => Err(map_err(e)), + }, + } +} diff --git a/third_party/rust/naga/src/front/wgsl/tests.rs b/third_party/rust/naga/src/front/wgsl/tests.rs new file mode 100644 index 0000000000..eb2f8a2eb3 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/tests.rs @@ -0,0 +1,637 @@ +use super::parse_str; + +#[test] +fn parse_comment() { + parse_str( + "// + //// + ///////////////////////////////////////////////////////// asda + //////////////////// dad ////////// / + ///////////////////////////////////////////////////////////////////////////////////////////////////// + // + ", + ) + .unwrap(); +} + +#[test] +fn parse_types() { + parse_str("const a : i32 = 2;").unwrap(); + assert!(parse_str("const a : x32 = 2;").is_err()); + parse_str("var t: texture_2d<f32>;").unwrap(); + parse_str("var t: texture_cube_array<i32>;").unwrap(); + parse_str("var t: texture_multisampled_2d<u32>;").unwrap(); + parse_str("var t: texture_storage_1d<rgba8uint,write>;").unwrap(); + parse_str("var t: texture_storage_3d<r32float,read>;").unwrap(); +} + +#[test] +fn parse_type_inference() { + parse_str( + " + fn foo() { + let a = 2u; + let b: u32 = a; + var x = 3.; + var y = vec2<f32>(1, 2); + }", + ) + .unwrap(); + assert!(parse_str( + " + fn foo() { let c : i32 = 2.0; }", + ) + .is_err()); +} + +#[test] +fn parse_type_cast() { + parse_str( + " + const a : i32 = 2; + fn main() { + var x: f32 = f32(a); + x = f32(i32(a + 1) / 2); + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + let x: vec2<f32> = vec2<f32>(1.0, 2.0); + let y: vec2<u32> = vec2<u32>(x); + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + let x: vec2<f32> = vec2<f32>(0.0); + } + ", + ) + .unwrap(); + assert!(parse_str( + " + fn main() { + let x: vec2<f32> = vec2<f32>(0i, 0i); + } + ", + ) + .is_err()); +} + +#[test] +fn parse_struct() { + parse_str( + " + struct Foo { x: i32 } + struct Bar { + @size(16) x: vec2<i32>, + @align(16) y: f32, + @size(32) @align(128) z: vec3<f32>, + }; + struct Empty {} + var<storage,read_write> s: Foo; + ", + ) + .unwrap(); +} + +#[test] +fn parse_standard_fun() { + parse_str( + " + fn main() { + var x: i32 = min(max(1, 2), 3); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_statement() { + parse_str( + " + fn main() { + ; + {} + {;} + } + ", + ) + .unwrap(); + + parse_str( + " + fn foo() {} + fn bar() { foo(); } + ", + ) + .unwrap(); +} + +#[test] +fn parse_if() { + parse_str( + " + fn main() { + if true { + discard; + } else {} + if 0 != 1 {} + if false { + return; + } else if true { + return; + } else {} + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_parentheses_if() { + parse_str( + " + fn main() { + if (true) { + discard; + } else {} + if (0 != 1) {} + if (false) { + return; + } else if (true) { + return; + } else {} + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_loop() { + parse_str( + " + fn main() { + var i: i32 = 0; + loop { + if i == 1 { break; } + continuing { i = 1; } + } + loop { + if i == 0 { continue; } + break; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + var found: bool = false; + var i: i32 = 0; + while !found { + if i == 10 { + found = true; + } + + i = i + 1; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + while true { + break; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + var a: i32 = 0; + for(var i: i32 = 0; i < 4; i = i + 1) { + a = a + 2; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + for(;;) { + break; + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_switch() { + parse_str( + " + fn main() { + var pos: f32; + switch (3) { + case 0, 1: { pos = 0.0; } + case 2: { pos = 1.0; } + default: { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_switch_optional_colon_in_case() { + parse_str( + " + fn main() { + var pos: f32; + switch (3) { + case 0, 1 { pos = 0.0; } + case 2 { pos = 1.0; } + default { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_switch_default_in_case() { + parse_str( + " + fn main() { + var pos: f32; + switch (3) { + case 0, 1: { pos = 0.0; } + case 2: {} + case default, 3: { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_parentheses_switch() { + parse_str( + " + fn main() { + var pos: f32; + switch pos > 1.0 { + default: { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_load() { + parse_str( + " + var t: texture_3d<u32>; + fn foo() { + let r: vec4<u32> = textureLoad(t, vec3<u32>(0u, 1u, 2u), 1); + } + ", + ) + .unwrap(); + parse_str( + " + var t: texture_multisampled_2d_array<i32>; + fn foo() { + let r: vec4<i32> = textureLoad(t, vec2<i32>(10, 20), 2, 3); + } + ", + ) + .unwrap(); + parse_str( + " + var t: texture_storage_1d_array<r32float,read>; + fn foo() { + let r: vec4<f32> = textureLoad(t, 10, 2); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_store() { + parse_str( + " + var t: texture_storage_2d<rgba8unorm,write>; + fn foo() { + textureStore(t, vec2<i32>(10, 20), vec4<f32>(0.0, 1.0, 2.0, 3.0)); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_query() { + parse_str( + " + var t: texture_multisampled_2d_array<f32>; + fn foo() { + var dim: vec2<u32> = textureDimensions(t); + dim = textureDimensions(t, 0); + let layers: u32 = textureNumLayers(t); + let samples: u32 = textureNumSamples(t); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_postfix() { + parse_str( + "fn foo() { + let x: f32 = vec4<f32>(1.0, 2.0, 3.0, 4.0).xyz.rgbr.aaaa.wz.g; + let y: f32 = fract(vec2<f32>(0.5, x)).x; + }", + ) + .unwrap(); +} + +#[test] +fn parse_expressions() { + parse_str("fn foo() { + let x: f32 = select(0.0, 1.0, true); + let y: vec2<f32> = select(vec2<f32>(1.0, 1.0), vec2<f32>(x, x), vec2<bool>(x < 0.5, x > 0.5)); + let z: bool = !(0.0 == 1.0); + }").unwrap(); +} + +#[test] +fn binary_expression_mixed_scalar_and_vector_operands() { + for (operand, expect_splat) in [ + ('<', false), + ('>', false), + ('&', false), + ('|', false), + ('+', true), + ('-', true), + ('*', false), + ('/', true), + ('%', true), + ] { + let module = parse_str(&format!( + " + @fragment + fn main(@location(0) some_vec: vec3<f32>) -> @location(0) vec4<f32> {{ + if (all(1.0 {operand} some_vec)) {{ + return vec4(0.0); + }} + return vec4(1.0); + }} + " + )) + .unwrap(); + + let expressions = &&module.entry_points[0].function.expressions; + + let found_expressions = expressions + .iter() + .filter(|&(_, e)| { + if let crate::Expression::Binary { left, .. } = *e { + matches!( + (expect_splat, &expressions[left]), + (false, &crate::Expression::Literal(crate::Literal::F32(..))) + | (true, &crate::Expression::Splat { .. }) + ) + } else { + false + } + }) + .count(); + + assert_eq!( + found_expressions, + 1, + "expected `{operand}` expression {} splat", + if expect_splat { "with" } else { "without" } + ); + } + + let module = parse_str( + "@fragment + fn main(mat: mat3x3<f32>) { + let vec = vec3<f32>(1.0, 1.0, 1.0); + let result = mat / vec; + }", + ) + .unwrap(); + let expressions = &&module.entry_points[0].function.expressions; + let found_splat = expressions.iter().any(|(_, e)| { + if let crate::Expression::Binary { left, .. } = *e { + matches!(&expressions[left], &crate::Expression::Splat { .. }) + } else { + false + } + }); + assert!(!found_splat, "'mat / vec' should not be splatted"); +} + +#[test] +fn parse_pointers() { + parse_str( + "fn foo(a: ptr<private, f32>) -> f32 { return *a; } + fn bar() { + var x: f32 = 1.0; + let px = &x; + let py = foo(px); + }", + ) + .unwrap(); +} + +#[test] +fn parse_struct_instantiation() { + parse_str( + " + struct Foo { + a: f32, + b: vec3<f32>, + } + + @fragment + fn fs_main() { + var foo: Foo = Foo(0.0, vec3<f32>(0.0, 1.0, 42.0)); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_array_length() { + parse_str( + " + struct Foo { + data: array<u32> + } // this is used as both input and output for convenience + + @group(0) @binding(0) + var<storage> foo: Foo; + + @group(0) @binding(1) + var<storage> bar: array<u32>; + + fn baz() { + var x: u32 = arrayLength(foo.data); + var y: u32 = arrayLength(bar); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_storage_buffers() { + parse_str( + " + @group(0) @binding(0) + var<storage> foo: array<u32>; + ", + ) + .unwrap(); + parse_str( + " + @group(0) @binding(0) + var<storage,read> foo: array<u32>; + ", + ) + .unwrap(); + parse_str( + " + @group(0) @binding(0) + var<storage,write> foo: array<u32>; + ", + ) + .unwrap(); + parse_str( + " + @group(0) @binding(0) + var<storage,read_write> foo: array<u32>; + ", + ) + .unwrap(); +} + +#[test] +fn parse_alias() { + parse_str( + " + alias Vec4 = vec4<f32>; + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_load_store_expecting_four_args() { + for (func, texture) in [ + ( + "textureStore", + "texture_storage_2d_array<rg11b10float, write>", + ), + ("textureLoad", "texture_2d_array<i32>"), + ] { + let error = parse_str(&format!( + " + @group(0) @binding(0) var tex_los_res: {texture}; + @compute + @workgroup_size(1) + fn main(@builtin(global_invocation_id) id: vec3<u32>) {{ + var color = vec4(1, 1, 1, 1); + {func}(tex_los_res, id, color); + }} + " + )) + .unwrap_err(); + assert_eq!( + error.message(), + "wrong number of arguments: expected 4, found 3" + ); + } +} + +#[test] +fn parse_repeated_attributes() { + use crate::{ + front::wgsl::{error::Error, Frontend}, + Span, + }; + + let template_vs = "@vertex fn vs() -> __REPLACE__ vec4<f32> { return vec4<f32>(0.0); }"; + let template_struct = "struct A { __REPLACE__ data: vec3<f32> }"; + let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array<i32>;"; + let template_stage = "__REPLACE__ fn vs() -> vec4<f32> { return vec4<f32>(0.0); }"; + for (attribute, template) in [ + ("align(16)", template_struct), + ("binding(0)", template_resource), + ("builtin(position)", template_vs), + ("compute", template_stage), + ("fragment", template_stage), + ("group(0)", template_resource), + ("interpolate(flat)", template_vs), + ("invariant", template_vs), + ("location(0)", template_vs), + ("size(16)", template_struct), + ("vertex", template_stage), + ("early_depth_test(less_equal)", template_resource), + ("workgroup_size(1)", template_stage), + ] { + let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}")); + let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32; + let span_start = shader.rfind(attribute).unwrap() as u32; + let span_end = span_start + name_length; + let expected_span = Span::new(span_start, span_end); + + let result = Frontend::new().inner(&shader); + assert!(matches!( + result.unwrap_err(), + Error::RepeatedAttribute(span) if span == expected_span + )); + } +} + +#[test] +fn parse_missing_workgroup_size() { + use crate::{ + front::wgsl::{error::Error, Frontend}, + Span, + }; + + let shader = "@compute fn vs() -> vec4<f32> { return vec4<f32>(0.0); }"; + let result = Frontend::new().inner(shader); + assert!(matches!( + result.unwrap_err(), + Error::MissingWorkgroupSize(span) if span == Span::new(1, 8) + )); +} diff --git a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs new file mode 100644 index 0000000000..c8331ace09 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs @@ -0,0 +1,283 @@ +//! Producing the WGSL forms of types, for use in error messages. + +use crate::proc::GlobalCtx; +use crate::Handle; + +impl crate::proc::TypeResolution { + pub fn to_wgsl(&self, gctx: &GlobalCtx) -> String { + match *self { + crate::proc::TypeResolution::Handle(handle) => handle.to_wgsl(gctx), + crate::proc::TypeResolution::Value(ref inner) => inner.to_wgsl(gctx), + } + } +} + +impl Handle<crate::Type> { + /// Formats the type as it is written in wgsl. + /// + /// For example `vec3<f32>`. + pub fn to_wgsl(self, gctx: &GlobalCtx) -> String { + let ty = &gctx.types[self]; + match ty.name { + Some(ref name) => name.clone(), + None => ty.inner.to_wgsl(gctx), + } + } +} + +impl crate::TypeInner { + /// Formats the type as it is written in wgsl. + /// + /// For example `vec3<f32>`. + /// + /// Note: `TypeInner::Struct` doesn't include the name of the + /// struct type. Therefore this method will simply return "struct" + /// for them. + pub fn to_wgsl(&self, gctx: &GlobalCtx) -> String { + use crate::TypeInner as Ti; + + match *self { + Ti::Scalar(scalar) => scalar.to_wgsl(), + Ti::Vector { size, scalar } => { + format!("vec{}<{}>", size as u32, scalar.to_wgsl()) + } + Ti::Matrix { + columns, + rows, + scalar, + } => { + format!( + "mat{}x{}<{}>", + columns as u32, + rows as u32, + scalar.to_wgsl(), + ) + } + Ti::Atomic(scalar) => { + format!("atomic<{}>", scalar.to_wgsl()) + } + Ti::Pointer { base, .. } => { + let name = base.to_wgsl(gctx); + format!("ptr<{name}>") + } + Ti::ValuePointer { scalar, .. } => { + format!("ptr<{}>", scalar.to_wgsl()) + } + Ti::Array { base, size, .. } => { + let base = base.to_wgsl(gctx); + match size { + crate::ArraySize::Constant(size) => format!("array<{base}, {size}>"), + crate::ArraySize::Dynamic => format!("array<{base}>"), + } + } + Ti::Struct { .. } => { + // TODO: Actually output the struct? + "struct".to_string() + } + Ti::Image { + dim, + arrayed, + class, + } => { + let dim_suffix = match dim { + crate::ImageDimension::D1 => "_1d", + crate::ImageDimension::D2 => "_2d", + crate::ImageDimension::D3 => "_3d", + crate::ImageDimension::Cube => "_cube", + }; + let array_suffix = if arrayed { "_array" } else { "" }; + + let class_suffix = match class { + crate::ImageClass::Sampled { multi: true, .. } => "_multisampled", + crate::ImageClass::Depth { multi: false } => "_depth", + crate::ImageClass::Depth { multi: true } => "_depth_multisampled", + crate::ImageClass::Sampled { multi: false, .. } + | crate::ImageClass::Storage { .. } => "", + }; + + let type_in_brackets = match class { + crate::ImageClass::Sampled { kind, .. } => { + // Note: The only valid widths are 4 bytes wide. + // The lexer has already verified this, so we can safely assume it here. + // https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type + let element_type = crate::Scalar { kind, width: 4 }.to_wgsl(); + format!("<{element_type}>") + } + crate::ImageClass::Depth { multi: _ } => String::new(), + crate::ImageClass::Storage { format, access } => { + if access.contains(crate::StorageAccess::STORE) { + format!("<{},write>", format.to_wgsl()) + } else { + format!("<{}>", format.to_wgsl()) + } + } + }; + + format!("texture{class_suffix}{dim_suffix}{array_suffix}{type_in_brackets}") + } + Ti::Sampler { .. } => "sampler".to_string(), + Ti::AccelerationStructure => "acceleration_structure".to_string(), + Ti::RayQuery => "ray_query".to_string(), + Ti::BindingArray { base, size, .. } => { + let member_type = &gctx.types[base]; + let base = member_type.name.as_deref().unwrap_or("unknown"); + match size { + crate::ArraySize::Constant(size) => format!("binding_array<{base}, {size}>"), + crate::ArraySize::Dynamic => format!("binding_array<{base}>"), + } + } + } + } +} + +impl crate::Scalar { + /// Format a scalar kind+width as a type is written in wgsl. + /// + /// Examples: `f32`, `u64`, `bool`. + pub fn to_wgsl(self) -> String { + let prefix = match self.kind { + crate::ScalarKind::Sint => "i", + crate::ScalarKind::Uint => "u", + crate::ScalarKind::Float => "f", + crate::ScalarKind::Bool => return "bool".to_string(), + crate::ScalarKind::AbstractInt => return "{AbstractInt}".to_string(), + crate::ScalarKind::AbstractFloat => return "{AbstractFloat}".to_string(), + }; + format!("{}{}", prefix, self.width * 8) + } +} + +impl crate::StorageFormat { + pub const fn to_wgsl(self) -> &'static str { + use crate::StorageFormat as Sf; + match self { + Sf::R8Unorm => "r8unorm", + Sf::R8Snorm => "r8snorm", + Sf::R8Uint => "r8uint", + Sf::R8Sint => "r8sint", + Sf::R16Uint => "r16uint", + Sf::R16Sint => "r16sint", + Sf::R16Float => "r16float", + Sf::Rg8Unorm => "rg8unorm", + Sf::Rg8Snorm => "rg8snorm", + Sf::Rg8Uint => "rg8uint", + Sf::Rg8Sint => "rg8sint", + Sf::R32Uint => "r32uint", + Sf::R32Sint => "r32sint", + Sf::R32Float => "r32float", + Sf::Rg16Uint => "rg16uint", + Sf::Rg16Sint => "rg16sint", + Sf::Rg16Float => "rg16float", + Sf::Rgba8Unorm => "rgba8unorm", + Sf::Rgba8Snorm => "rgba8snorm", + Sf::Rgba8Uint => "rgba8uint", + Sf::Rgba8Sint => "rgba8sint", + Sf::Bgra8Unorm => "bgra8unorm", + Sf::Rgb10a2Uint => "rgb10a2uint", + Sf::Rgb10a2Unorm => "rgb10a2unorm", + Sf::Rg11b10Float => "rg11b10float", + Sf::Rg32Uint => "rg32uint", + Sf::Rg32Sint => "rg32sint", + Sf::Rg32Float => "rg32float", + Sf::Rgba16Uint => "rgba16uint", + Sf::Rgba16Sint => "rgba16sint", + Sf::Rgba16Float => "rgba16float", + Sf::Rgba32Uint => "rgba32uint", + Sf::Rgba32Sint => "rgba32sint", + Sf::Rgba32Float => "rgba32float", + Sf::R16Unorm => "r16unorm", + Sf::R16Snorm => "r16snorm", + Sf::Rg16Unorm => "rg16unorm", + Sf::Rg16Snorm => "rg16snorm", + Sf::Rgba16Unorm => "rgba16unorm", + Sf::Rgba16Snorm => "rgba16snorm", + } + } +} + +mod tests { + #[test] + fn to_wgsl() { + use std::num::NonZeroU32; + + let mut types = crate::UniqueArena::new(); + + let mytype1 = types.insert( + crate::Type { + name: Some("MyType1".to_string()), + inner: crate::TypeInner::Struct { + members: vec![], + span: 0, + }, + }, + Default::default(), + ); + let mytype2 = types.insert( + crate::Type { + name: Some("MyType2".to_string()), + inner: crate::TypeInner::Struct { + members: vec![], + span: 0, + }, + }, + Default::default(), + ); + + let gctx = crate::proc::GlobalCtx { + types: &types, + constants: &crate::Arena::new(), + const_expressions: &crate::Arena::new(), + }; + let array = crate::TypeInner::Array { + base: mytype1, + stride: 4, + size: crate::ArraySize::Constant(unsafe { NonZeroU32::new_unchecked(32) }), + }; + assert_eq!(array.to_wgsl(&gctx), "array<MyType1, 32>"); + + let mat = crate::TypeInner::Matrix { + rows: crate::VectorSize::Quad, + columns: crate::VectorSize::Bi, + scalar: crate::Scalar::F64, + }; + assert_eq!(mat.to_wgsl(&gctx), "mat2x4<f64>"); + + let ptr = crate::TypeInner::Pointer { + base: mytype2, + space: crate::AddressSpace::Storage { + access: crate::StorageAccess::default(), + }, + }; + assert_eq!(ptr.to_wgsl(&gctx), "ptr<MyType2>"); + + let img1 = crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: crate::ScalarKind::Float, + multi: true, + }, + }; + assert_eq!(img1.to_wgsl(&gctx), "texture_multisampled_2d<f32>"); + + let img2 = crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Depth { multi: false }, + }; + assert_eq!(img2.to_wgsl(&gctx), "texture_depth_cube_array"); + + let img3 = crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth { multi: true }, + }; + assert_eq!(img3.to_wgsl(&gctx), "texture_depth_multisampled_2d"); + + let array = crate::TypeInner::BindingArray { + base: mytype1, + size: crate::ArraySize::Constant(unsafe { NonZeroU32::new_unchecked(32) }), + }; + assert_eq!(array.to_wgsl(&gctx), "binding_array<MyType1, 32>"); + } +} |