summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/front
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/naga/src/front
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/front')
-rw-r--r--third_party/rust/naga/src/front/glsl/ast.rs393
-rw-r--r--third_party/rust/naga/src/front/glsl/builtins.rs2497
-rw-r--r--third_party/rust/naga/src/front/glsl/constants.rs984
-rw-r--r--third_party/rust/naga/src/front/glsl/context.rs1609
-rw-r--r--third_party/rust/naga/src/front/glsl/error.rs134
-rw-r--r--third_party/rust/naga/src/front/glsl/functions.rs1680
-rw-r--r--third_party/rust/naga/src/front/glsl/lex.rs301
-rw-r--r--third_party/rust/naga/src/front/glsl/mod.rs235
-rw-r--r--third_party/rust/naga/src/front/glsl/offset.rs173
-rw-r--r--third_party/rust/naga/src/front/glsl/parser.rs448
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/declarations.rs671
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/expressions.rs546
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/functions.rs654
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/types.rs434
-rw-r--r--third_party/rust/naga/src/front/glsl/parser_tests.rs832
-rw-r--r--third_party/rust/naga/src/front/glsl/token.rs137
-rw-r--r--third_party/rust/naga/src/front/glsl/types.rs335
-rw-r--r--third_party/rust/naga/src/front/glsl/variables.rs679
-rw-r--r--third_party/rust/naga/src/front/interpolator.rs61
-rw-r--r--third_party/rust/naga/src/front/mod.rs374
-rw-r--r--third_party/rust/naga/src/front/spv/convert.rs181
-rw-r--r--third_party/rust/naga/src/front/spv/error.rs127
-rw-r--r--third_party/rust/naga/src/front/spv/function.rs661
-rw-r--r--third_party/rust/naga/src/front/spv/image.rs737
-rw-r--r--third_party/rust/naga/src/front/spv/mod.rs5195
-rw-r--r--third_party/rust/naga/src/front/spv/null.rs175
-rw-r--r--third_party/rust/naga/src/front/type_gen.rs314
-rw-r--r--third_party/rust/naga/src/front/wgsl/error.rs690
-rw-r--r--third_party/rust/naga/src/front/wgsl/index.rs193
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/construction.rs668
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/mod.rs2426
-rw-r--r--third_party/rust/naga/src/front/wgsl/mod.rs340
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/ast.rs486
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/conv.rs236
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/lexer.rs723
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/mod.rs2298
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/number.rs443
-rw-r--r--third_party/rust/naga/src/front/wgsl/tests.rs483
38 files changed, 29553 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/front/glsl/ast.rs b/third_party/rust/naga/src/front/glsl/ast.rs
new file mode 100644
index 0000000000..cbb21f6de9
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/ast.rs
@@ -0,0 +1,393 @@
+use std::{borrow::Cow, fmt};
+
+use super::{builtins::MacroCall, context::ExprPos, Span};
+use crate::{
+ AddressSpace, BinaryOperator, Binding, Constant, Expression, Function, GlobalVariable, Handle,
+ Interpolation, Sampling, StorageAccess, Type, UnaryOperator,
+};
+
+#[derive(Debug, Clone, Copy)]
+pub enum GlobalLookupKind {
+ Variable(Handle<GlobalVariable>),
+ Constant(Handle<Constant>, Handle<Type>),
+ BlockSelect(Handle<GlobalVariable>, u32),
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct GlobalLookup {
+ pub kind: GlobalLookupKind,
+ pub entry_arg: Option<usize>,
+ pub mutable: bool,
+}
+
+#[derive(Debug, Clone)]
+pub struct ParameterInfo {
+ pub qualifier: ParameterQualifier,
+ /// Whether the parameter should be treated as a depth image instead of a
+ /// sampled image.
+ pub depth: bool,
+}
+
+/// How the function is implemented
+#[derive(Clone, Copy)]
+pub enum FunctionKind {
+ /// The function is user defined
+ Call(Handle<Function>),
+ /// The function is a builtin
+ Macro(MacroCall),
+}
+
+impl fmt::Debug for FunctionKind {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match *self {
+ Self::Call(_) => write!(f, "Call"),
+ Self::Macro(_) => write!(f, "Macro"),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Overload {
+ /// Normalized function parameters, modifiers are not applied
+ pub parameters: Vec<Handle<Type>>,
+ pub parameters_info: Vec<ParameterInfo>,
+ /// How the function is implemented
+ pub kind: FunctionKind,
+ /// Whether this function was already defined or is just a prototype
+ pub defined: bool,
+ /// Whether this overload is the one provided by the language or has
+ /// been redeclared by the user (builtins only)
+ pub internal: bool,
+ /// Whether or not this function returns void (nothing)
+ pub void: bool,
+}
+
+bitflags::bitflags! {
+ /// Tracks the variations of the builtin already generated, this is needed because some
+ /// builtins overloads can't be generated unless explicitly used, since they might cause
+ /// unneeded capabilities to be requested
+ #[derive(Default)]
+ pub struct BuiltinVariations: u32 {
+ /// Request the standard overloads
+ const STANDARD = 1 << 0;
+ /// Request overloads that use the double type
+ const DOUBLE = 1 << 1;
+ /// Request overloads that use samplerCubeArray(Shadow)
+ const CUBE_TEXTURES_ARRAY = 1 << 2;
+ /// Request overloads that use sampler2DMSArray
+ const D2_MULTI_TEXTURES_ARRAY = 1 << 3;
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct FunctionDeclaration {
+ pub overloads: Vec<Overload>,
+ /// Tracks the builtin overload variations that were already generated
+ pub variations: BuiltinVariations,
+}
+
+#[derive(Debug)]
+pub struct EntryArg {
+ pub name: Option<String>,
+ pub binding: Binding,
+ pub handle: Handle<GlobalVariable>,
+ pub storage: StorageQualifier,
+}
+
+#[derive(Debug, Clone)]
+pub struct VariableReference {
+ pub expr: Handle<Expression>,
+ /// Wether the variable is of a pointer type (and needs loading) or not
+ pub load: bool,
+ /// Wether the value of the variable can be changed or not
+ pub mutable: bool,
+ pub constant: Option<(Handle<Constant>, Handle<Type>)>,
+ pub entry_arg: Option<usize>,
+}
+
+#[derive(Debug, Clone)]
+pub struct HirExpr {
+ pub kind: HirExprKind,
+ pub meta: Span,
+}
+
+#[derive(Debug, Clone)]
+pub enum HirExprKind {
+ Access {
+ base: Handle<HirExpr>,
+ index: Handle<HirExpr>,
+ },
+ Select {
+ base: Handle<HirExpr>,
+ field: String,
+ },
+ Constant(Handle<Constant>),
+ Binary {
+ left: Handle<HirExpr>,
+ op: BinaryOperator,
+ right: Handle<HirExpr>,
+ },
+ Unary {
+ op: UnaryOperator,
+ expr: Handle<HirExpr>,
+ },
+ Variable(VariableReference),
+ Call(FunctionCall),
+ /// Represents the ternary operator in glsl (`:?`)
+ Conditional {
+ /// The expression that will decide which branch to take, must evaluate to a boolean
+ condition: Handle<HirExpr>,
+ /// The expression that will be evaluated if [`condition`] returns `true`
+ ///
+ /// [`condition`]: Self::Conditional::condition
+ accept: Handle<HirExpr>,
+ /// The expression that will be evaluated if [`condition`] returns `false`
+ ///
+ /// [`condition`]: Self::Conditional::condition
+ reject: Handle<HirExpr>,
+ },
+ Assign {
+ tgt: Handle<HirExpr>,
+ value: Handle<HirExpr>,
+ },
+ /// A prefix/postfix operator like `++`
+ PrePostfix {
+ /// The operation to be performed
+ op: BinaryOperator,
+ /// Whether this is a postfix or a prefix
+ postfix: bool,
+ /// The target expression
+ expr: Handle<HirExpr>,
+ },
+ /// A method call like `what.something(a, b, c)`
+ Method {
+ /// expression the method call applies to (`what` in the example)
+ expr: Handle<HirExpr>,
+ /// the method name (`something` in the example)
+ name: String,
+ /// the arguments to the method (`a`, `b`, and `c` in the example)
+ args: Vec<Handle<HirExpr>>,
+ },
+}
+
+#[derive(Debug, Hash, PartialEq, Eq)]
+pub enum QualifierKey<'a> {
+ String(Cow<'a, str>),
+ /// Used for `std140` and `std430` layout qualifiers
+ Layout,
+ /// Used for image formats
+ Format,
+}
+
+#[derive(Debug)]
+pub enum QualifierValue {
+ None,
+ Uint(u32),
+ Layout(StructLayout),
+ Format(crate::StorageFormat),
+}
+
+#[derive(Debug, Default)]
+pub struct TypeQualifiers<'a> {
+ pub span: Span,
+ pub storage: (StorageQualifier, Span),
+ pub invariant: Option<Span>,
+ pub interpolation: Option<(Interpolation, Span)>,
+ pub precision: Option<(Precision, Span)>,
+ pub sampling: Option<(Sampling, Span)>,
+ /// Memory qualifiers used in the declaration to set the storage access to be used
+ /// in declarations that support it (storage images and buffers)
+ pub storage_access: Option<(StorageAccess, Span)>,
+ pub layout_qualifiers: crate::FastHashMap<QualifierKey<'a>, (QualifierValue, Span)>,
+}
+
+impl<'a> TypeQualifiers<'a> {
+ /// Appends `errors` with errors for all unused qualifiers
+ pub fn unused_errors(&self, errors: &mut Vec<super::Error>) {
+ if let Some(meta) = self.invariant {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Invariant qualifier can only be used in in/out variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ if let Some((_, meta)) = self.interpolation {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Interpolation qualifiers can only be used in in/out variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ if let Some((_, meta)) = self.sampling {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Sampling qualifiers can only be used in in/out variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ if let Some((_, meta)) = self.storage_access {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Memory qualifiers can only be used in storage variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ for &(_, meta) in self.layout_qualifiers.values() {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError("Unexpected qualifier".into()),
+ meta,
+ });
+ }
+ }
+
+ /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't
+ /// a [`QualifierValue::Uint`]
+ pub fn uint_layout_qualifier(
+ &mut self,
+ name: &'a str,
+ errors: &mut Vec<super::Error>,
+ ) -> Option<u32> {
+ match self
+ .layout_qualifiers
+ .remove(&QualifierKey::String(name.into()))
+ {
+ Some((QualifierValue::Uint(v), _)) => Some(v),
+ Some((_, meta)) => {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError("Qualifier expects a uint value".into()),
+ meta,
+ });
+ // Return a dummy value instead of `None` to differentiate from
+ // the qualifier not existing, since some parts might require the
+ // qualifier to exist and throwing another error that it doesn't
+ // exist would be unhelpful
+ Some(0)
+ }
+ _ => None,
+ }
+ }
+
+ /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't
+ /// a [`QualifierValue::None`]
+ pub fn none_layout_qualifier(&mut self, name: &'a str, errors: &mut Vec<super::Error>) -> bool {
+ match self
+ .layout_qualifiers
+ .remove(&QualifierKey::String(name.into()))
+ {
+ Some((QualifierValue::None, _)) => true,
+ Some((_, meta)) => {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Qualifier doesn't expect a value".into(),
+ ),
+ meta,
+ });
+ // Return a `true` to since the qualifier is defined and adding
+ // another error for it not being defined would be unhelpful
+ true
+ }
+ _ => false,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub enum FunctionCallKind {
+ TypeConstructor(Handle<Type>),
+ Function(String),
+}
+
+#[derive(Debug, Clone)]
+pub struct FunctionCall {
+ pub kind: FunctionCallKind,
+ pub args: Vec<Handle<HirExpr>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum StorageQualifier {
+ AddressSpace(AddressSpace),
+ Input,
+ Output,
+ Const,
+}
+
+impl Default for StorageQualifier {
+ fn default() -> Self {
+ StorageQualifier::AddressSpace(AddressSpace::Function)
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum StructLayout {
+ Std140,
+ Std430,
+}
+
+// TODO: Encode precision hints in the IR
+/// A precision hint used in GLSL declarations.
+///
+/// Precision hints can be used to either speed up shader execution or control
+/// the precision of arithmetic operations.
+///
+/// To use a precision hint simply add it before the type in the declaration.
+/// ```glsl
+/// mediump float a;
+/// ```
+///
+/// The default when no precision is declared is `highp` which means that all
+/// operations operate with the type defined width.
+///
+/// For `mediump` and `lowp` operations follow the spir-v
+/// [`RelaxedPrecision`][RelaxedPrecision] decoration semantics.
+///
+/// [RelaxedPrecision]: https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#_a_id_relaxedprecisionsection_a_relaxed_precision
+#[derive(Debug, Clone, PartialEq, Copy)]
+pub enum Precision {
+ /// `lowp` precision
+ Low,
+ /// `mediump` precision
+ Medium,
+ /// `highp` precision
+ High,
+}
+
+#[derive(Debug, Clone, PartialEq, Copy)]
+pub enum ParameterQualifier {
+ In,
+ Out,
+ InOut,
+ Const,
+}
+
+impl ParameterQualifier {
+ /// Returns true if the argument should be passed as a lhs expression
+ pub const fn is_lhs(&self) -> bool {
+ match *self {
+ ParameterQualifier::Out | ParameterQualifier::InOut => true,
+ _ => false,
+ }
+ }
+
+ /// Converts from a parameter qualifier into a [`ExprPos`](ExprPos)
+ pub const fn as_pos(&self) -> ExprPos {
+ match *self {
+ ParameterQualifier::Out | ParameterQualifier::InOut => ExprPos::Lhs,
+ _ => ExprPos::Rhs,
+ }
+ }
+}
+
+/// The GLSL profile used by a shader.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum Profile {
+ /// The `core` profile, default when no profile is specified.
+ Core,
+}
diff --git a/third_party/rust/naga/src/front/glsl/builtins.rs b/third_party/rust/naga/src/front/glsl/builtins.rs
new file mode 100644
index 0000000000..14e223d6a4
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/builtins.rs
@@ -0,0 +1,2497 @@
+use super::{
+ ast::{
+ BuiltinVariations, FunctionDeclaration, FunctionKind, Overload, ParameterInfo,
+ ParameterQualifier,
+ },
+ context::Context,
+ Error, ErrorKind, Frontend, Result,
+};
+use crate::{
+ BinaryOperator, Block, Constant, DerivativeAxis as Axis, DerivativeControl as Ctrl, Expression,
+ Handle, ImageClass, ImageDimension as Dim, ImageQuery, MathFunction, Module,
+ RelationalFunction, SampleLevel, ScalarKind as Sk, Span, Type, TypeInner, UnaryOperator,
+ VectorSize,
+};
+
+impl crate::ScalarKind {
+ const fn dummy_storage_format(&self) -> crate::StorageFormat {
+ match *self {
+ Sk::Sint => crate::StorageFormat::R16Sint,
+ Sk::Uint => crate::StorageFormat::R16Uint,
+ _ => crate::StorageFormat::R16Float,
+ }
+ }
+}
+
+impl Module {
+ /// Helper function, to create a function prototype for a builtin
+ fn add_builtin(&mut self, args: Vec<TypeInner>, builtin: MacroCall) -> Overload {
+ let mut parameters = Vec::with_capacity(args.len());
+ let mut parameters_info = Vec::with_capacity(args.len());
+
+ for arg in args {
+ parameters.push(self.types.insert(
+ Type {
+ name: None,
+ inner: arg,
+ },
+ Span::default(),
+ ));
+ parameters_info.push(ParameterInfo {
+ qualifier: ParameterQualifier::In,
+ depth: false,
+ });
+ }
+
+ Overload {
+ parameters,
+ parameters_info,
+ kind: FunctionKind::Macro(builtin),
+ defined: false,
+ internal: true,
+ void: false,
+ }
+ }
+}
+
+const fn make_coords_arg(number_of_components: usize, kind: Sk) -> TypeInner {
+ let width = 4;
+
+ match number_of_components {
+ 1 => TypeInner::Scalar { kind, width },
+ _ => TypeInner::Vector {
+ size: match number_of_components {
+ 2 => VectorSize::Bi,
+ 3 => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ },
+ kind,
+ width,
+ },
+ }
+}
+
+/// Inject builtins into the declaration
+///
+/// This is done to not add a large startup cost and not increase memory
+/// usage if it isn't needed.
+pub fn inject_builtin(
+ declaration: &mut FunctionDeclaration,
+ module: &mut Module,
+ name: &str,
+ mut variations: BuiltinVariations,
+) {
+ log::trace!(
+ "{} variations: {:?} {:?}",
+ name,
+ variations,
+ declaration.variations
+ );
+ // Don't regeneate variations
+ variations.remove(declaration.variations);
+ declaration.variations |= variations;
+
+ if variations.contains(BuiltinVariations::STANDARD) {
+ inject_standard_builtins(declaration, module, name)
+ }
+
+ if variations.contains(BuiltinVariations::DOUBLE) {
+ inject_double_builtin(declaration, module, name)
+ }
+
+ let width = 4;
+ match name {
+ "texture"
+ | "textureGrad"
+ | "textureGradOffset"
+ | "textureLod"
+ | "textureLodOffset"
+ | "textureOffset"
+ | "textureProj"
+ | "textureProjGrad"
+ | "textureProjGradOffset"
+ | "textureProjLod"
+ | "textureProjLodOffset"
+ | "textureProjOffset" => {
+ let f = |kind, dim, arrayed, multi, shadow| {
+ for bits in 0..=0b11 {
+ let variant = bits & 0b1 != 0;
+ let bias = bits & 0b10 != 0;
+
+ let (proj, offset, level_type) = match name {
+ // texture(gsampler, gvec P, [float bias]);
+ "texture" => (false, false, TextureLevelType::None),
+ // textureGrad(gsampler, gvec P, gvec dPdx, gvec dPdy);
+ "textureGrad" => (false, false, TextureLevelType::Grad),
+ // textureGradOffset(gsampler, gvec P, gvec dPdx, gvec dPdy, ivec offset);
+ "textureGradOffset" => (false, true, TextureLevelType::Grad),
+ // textureLod(gsampler, gvec P, float lod);
+ "textureLod" => (false, false, TextureLevelType::Lod),
+ // textureLodOffset(gsampler, gvec P, float lod, ivec offset);
+ "textureLodOffset" => (false, true, TextureLevelType::Lod),
+ // textureOffset(gsampler, gvec+1 P, ivec offset, [float bias]);
+ "textureOffset" => (false, true, TextureLevelType::None),
+ // textureProj(gsampler, gvec+1 P, [float bias]);
+ "textureProj" => (true, false, TextureLevelType::None),
+ // textureProjGrad(gsampler, gvec+1 P, gvec dPdx, gvec dPdy);
+ "textureProjGrad" => (true, false, TextureLevelType::Grad),
+ // textureProjGradOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset);
+ "textureProjGradOffset" => (true, true, TextureLevelType::Grad),
+ // textureProjLod(gsampler, gvec+1 P, float lod);
+ "textureProjLod" => (true, false, TextureLevelType::Lod),
+ // textureProjLodOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset);
+ "textureProjLodOffset" => (true, true, TextureLevelType::Lod),
+ // textureProjOffset(gsampler, gvec+1 P, ivec offset, [float bias]);
+ "textureProjOffset" => (true, true, TextureLevelType::None),
+ _ => unreachable!(),
+ };
+
+ let builtin = MacroCall::Texture {
+ proj,
+ offset,
+ shadow,
+ level_type,
+ };
+
+ // Parse out the variant settings.
+ let grad = level_type == TextureLevelType::Grad;
+ let lod = level_type == TextureLevelType::Lod;
+
+ let supports_variant = proj && !shadow;
+ if variant && !supports_variant {
+ continue;
+ }
+
+ if bias && !matches!(level_type, TextureLevelType::None) {
+ continue;
+ }
+
+ // Proj doesn't work with arrayed or Cube
+ if proj && (arrayed || dim == Dim::Cube) {
+ continue;
+ }
+
+ // texture operations with offset are not supported for cube maps
+ if dim == Dim::Cube && offset {
+ continue;
+ }
+
+ // sampler2DArrayShadow can't be used in textureLod or in texture with bias
+ if (lod || bias) && arrayed && shadow && dim == Dim::D2 {
+ continue;
+ }
+
+ // TODO: glsl supports using bias with depth samplers but naga doesn't
+ if bias && shadow {
+ continue;
+ }
+
+ let class = match shadow {
+ true => ImageClass::Depth { multi },
+ false => ImageClass::Sampled { kind, multi },
+ };
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ };
+
+ let num_coords_from_dim = image_dims_to_coords_size(dim).min(3);
+ let mut num_coords = num_coords_from_dim;
+
+ if shadow && proj {
+ num_coords = 4;
+ } else if dim == Dim::D1 && shadow {
+ num_coords = 3;
+ } else if shadow {
+ num_coords += 1;
+ } else if proj {
+ if variant && num_coords == 4 {
+ // Normal form already has 4 components, no need to have a variant form.
+ continue;
+ } else if variant {
+ num_coords = 4;
+ } else {
+ num_coords += 1;
+ }
+ }
+
+ if !(dim == Dim::D1 && shadow) {
+ num_coords += arrayed as usize;
+ }
+
+ // Special case: texture(gsamplerCubeArrayShadow) kicks the shadow compare ref to a separate argument,
+ // since it would otherwise take five arguments. It also can't take a bias, nor can it be proj/grad/lod/offset
+ // (presumably because nobody asked for it, and implementation complexity?)
+ if num_coords >= 5 {
+ if lod || grad || offset || proj || bias {
+ continue;
+ }
+ debug_assert!(dim == Dim::Cube && shadow && arrayed);
+ }
+ debug_assert!(num_coords <= 5);
+
+ let vector = make_coords_arg(num_coords, Sk::Float);
+ let mut args = vec![image, vector];
+
+ if num_coords == 5 {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Float,
+ width,
+ });
+ }
+
+ match level_type {
+ TextureLevelType::Lod => {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Float,
+ width,
+ });
+ }
+ TextureLevelType::Grad => {
+ args.push(make_coords_arg(num_coords_from_dim, Sk::Float));
+ args.push(make_coords_arg(num_coords_from_dim, Sk::Float));
+ }
+ TextureLevelType::None => {}
+ };
+
+ if offset {
+ args.push(make_coords_arg(num_coords_from_dim, Sk::Sint));
+ }
+
+ if bias {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Float,
+ width,
+ });
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, builtin));
+ }
+ };
+
+ texture_args_generator(TextureArgsOptions::SHADOW | variations.into(), f)
+ }
+ "textureSize" => {
+ let f = |kind, dim, arrayed, multi, shadow| {
+ let class = match shadow {
+ true => ImageClass::Depth { multi },
+ false => ImageClass::Sampled { kind, multi },
+ };
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ };
+
+ let mut args = vec![image];
+
+ if !multi {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width,
+ })
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::TextureSize { arrayed }))
+ };
+
+ texture_args_generator(
+ TextureArgsOptions::SHADOW | TextureArgsOptions::MULTI | variations.into(),
+ f,
+ )
+ }
+ "texelFetch" | "texelFetchOffset" => {
+ let offset = "texelFetchOffset" == name;
+ let f = |kind, dim, arrayed, multi, _shadow| {
+ // Cube images aren't supported
+ if let Dim::Cube = dim {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Sampled { kind, multi },
+ };
+
+ let dim_value = image_dims_to_coords_size(dim);
+ let coordinates = make_coords_arg(dim_value + arrayed as usize, Sk::Sint);
+
+ let mut args = vec![
+ image,
+ coordinates,
+ TypeInner::Scalar {
+ kind: Sk::Sint,
+ width,
+ },
+ ];
+
+ if offset {
+ args.push(make_coords_arg(dim_value, Sk::Sint));
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::ImageLoad { multi }))
+ };
+
+ // Don't generate shadow images since they aren't supported
+ texture_args_generator(TextureArgsOptions::MULTI | variations.into(), f)
+ }
+ "imageSize" => {
+ let f = |kind: Sk, dim, arrayed, _, _| {
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ if dim == Dim::Cube {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Storage {
+ format: kind.dummy_storage_format(),
+ access: crate::StorageAccess::empty(),
+ },
+ };
+
+ declaration
+ .overloads
+ .push(module.add_builtin(vec![image], MacroCall::TextureSize { arrayed }))
+ };
+
+ texture_args_generator(variations.into(), f)
+ }
+ "imageLoad" => {
+ let f = |kind: Sk, dim, arrayed, _, _| {
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ if dim == Dim::Cube {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Storage {
+ format: kind.dummy_storage_format(),
+ access: crate::StorageAccess::LOAD,
+ },
+ };
+
+ let dim_value = image_dims_to_coords_size(dim);
+ let mut coord_size = dim_value + arrayed as usize;
+ // > Every OpenGL API call that operates on cubemap array
+ // > textures takes layer-faces, not array layers
+ //
+ // So this means that imageCubeArray only takes a three component
+ // vector coordinate and the third component is a layer index.
+ if Dim::Cube == dim && arrayed {
+ coord_size = 3
+ }
+ let coordinates = make_coords_arg(coord_size, Sk::Sint);
+
+ let args = vec![image, coordinates];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::ImageLoad { multi: false }))
+ };
+
+ // Don't generate shadow nor multisampled images since they aren't supported
+ texture_args_generator(variations.into(), f)
+ }
+ "imageStore" => {
+ let f = |kind: Sk, dim, arrayed, _, _| {
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ if dim == Dim::Cube {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Storage {
+ format: kind.dummy_storage_format(),
+ access: crate::StorageAccess::STORE,
+ },
+ };
+
+ let dim_value = image_dims_to_coords_size(dim);
+ let mut coord_size = dim_value + arrayed as usize;
+ // > Every OpenGL API call that operates on cubemap array
+ // > textures takes layer-faces, not array layers
+ //
+ // So this means that imageCubeArray only takes a three component
+ // vector coordinate and the third component is a layer index.
+ if Dim::Cube == dim && arrayed {
+ coord_size = 3
+ }
+ let coordinates = make_coords_arg(coord_size, Sk::Sint);
+
+ let args = vec![
+ image,
+ coordinates,
+ TypeInner::Vector {
+ size: VectorSize::Quad,
+ kind,
+ width,
+ },
+ ];
+
+ let mut overload = module.add_builtin(args, MacroCall::ImageStore);
+ overload.void = true;
+ declaration.overloads.push(overload)
+ };
+
+ // Don't generate shadow nor multisampled images since they aren't supported
+ texture_args_generator(variations.into(), f)
+ }
+ _ => {}
+ }
+}
+
+/// Injects the builtins into declaration that don't need any special variations
+fn inject_standard_builtins(
+ declaration: &mut FunctionDeclaration,
+ module: &mut Module,
+ name: &str,
+) {
+ let width = 4;
+ match name {
+ "sampler1D" | "sampler1DArray" | "sampler2D" | "sampler2DArray" | "sampler2DMS"
+ | "sampler2DMSArray" | "sampler3D" | "samplerCube" | "samplerCubeArray" => {
+ declaration.overloads.push(module.add_builtin(
+ vec![
+ TypeInner::Image {
+ dim: match name {
+ "sampler1D" | "sampler1DArray" => Dim::D1,
+ "sampler2D" | "sampler2DArray" | "sampler2DMS" | "sampler2DMSArray" => {
+ Dim::D2
+ }
+ "sampler3D" => Dim::D3,
+ _ => Dim::Cube,
+ },
+ arrayed: matches!(
+ name,
+ "sampler1DArray"
+ | "sampler2DArray"
+ | "sampler2DMSArray"
+ | "samplerCubeArray"
+ ),
+ class: ImageClass::Sampled {
+ kind: Sk::Float,
+ multi: matches!(name, "sampler2DMS" | "sampler2DMSArray"),
+ },
+ },
+ TypeInner::Sampler { comparison: false },
+ ],
+ MacroCall::Sampler,
+ ))
+ }
+ "sampler1DShadow"
+ | "sampler1DArrayShadow"
+ | "sampler2DShadow"
+ | "sampler2DArrayShadow"
+ | "samplerCubeShadow"
+ | "samplerCubeArrayShadow" => {
+ let dim = match name {
+ "sampler1DShadow" | "sampler1DArrayShadow" => Dim::D1,
+ "sampler2DShadow" | "sampler2DArrayShadow" => Dim::D2,
+ _ => Dim::Cube,
+ };
+ let arrayed = matches!(
+ name,
+ "sampler1DArrayShadow" | "sampler2DArrayShadow" | "samplerCubeArrayShadow"
+ );
+
+ for i in 0..2 {
+ let ty = TypeInner::Image {
+ dim,
+ arrayed,
+ class: match i {
+ 0 => ImageClass::Sampled {
+ kind: Sk::Float,
+ multi: false,
+ },
+ _ => ImageClass::Depth { multi: false },
+ },
+ };
+
+ declaration.overloads.push(module.add_builtin(
+ vec![ty, TypeInner::Sampler { comparison: true }],
+ MacroCall::SamplerShadow,
+ ))
+ }
+ }
+ "sin" | "exp" | "exp2" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin"
+ | "log" | "log2" | "radians" | "degrees" | "asinh" | "acosh" | "atanh"
+ | "floatBitsToInt" | "floatBitsToUint" | "dFdx" | "dFdxFine" | "dFdxCoarse" | "dFdy"
+ | "dFdyFine" | "dFdyCoarse" | "fwidth" | "fwidthFine" | "fwidthCoarse" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = Sk::Float;
+
+ declaration.overloads.push(module.add_builtin(
+ vec![match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ }],
+ match name {
+ "sin" => MacroCall::MathFunction(MathFunction::Sin),
+ "exp" => MacroCall::MathFunction(MathFunction::Exp),
+ "exp2" => MacroCall::MathFunction(MathFunction::Exp2),
+ "sinh" => MacroCall::MathFunction(MathFunction::Sinh),
+ "cos" => MacroCall::MathFunction(MathFunction::Cos),
+ "cosh" => MacroCall::MathFunction(MathFunction::Cosh),
+ "tan" => MacroCall::MathFunction(MathFunction::Tan),
+ "tanh" => MacroCall::MathFunction(MathFunction::Tanh),
+ "acos" => MacroCall::MathFunction(MathFunction::Acos),
+ "asin" => MacroCall::MathFunction(MathFunction::Asin),
+ "log" => MacroCall::MathFunction(MathFunction::Log),
+ "log2" => MacroCall::MathFunction(MathFunction::Log2),
+ "asinh" => MacroCall::MathFunction(MathFunction::Asinh),
+ "acosh" => MacroCall::MathFunction(MathFunction::Acosh),
+ "atanh" => MacroCall::MathFunction(MathFunction::Atanh),
+ "radians" => MacroCall::MathFunction(MathFunction::Radians),
+ "degrees" => MacroCall::MathFunction(MathFunction::Degrees),
+ "floatBitsToInt" => MacroCall::BitCast(Sk::Sint),
+ "floatBitsToUint" => MacroCall::BitCast(Sk::Uint),
+ "dFdxCoarse" => MacroCall::Derivate(Axis::X, Ctrl::Coarse),
+ "dFdyCoarse" => MacroCall::Derivate(Axis::Y, Ctrl::Coarse),
+ "fwidthCoarse" => MacroCall::Derivate(Axis::Width, Ctrl::Coarse),
+ "dFdxFine" => MacroCall::Derivate(Axis::X, Ctrl::Fine),
+ "dFdyFine" => MacroCall::Derivate(Axis::Y, Ctrl::Fine),
+ "fwidthFine" => MacroCall::Derivate(Axis::Width, Ctrl::Fine),
+ "dFdx" => MacroCall::Derivate(Axis::X, Ctrl::None),
+ "dFdy" => MacroCall::Derivate(Axis::Y, Ctrl::None),
+ "fwidth" => MacroCall::Derivate(Axis::Width, Ctrl::None),
+ _ => unreachable!(),
+ },
+ ))
+ }
+ }
+ "intBitsToFloat" | "uintBitsToFloat" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = match name {
+ "intBitsToFloat" => Sk::Sint,
+ _ => Sk::Uint,
+ };
+
+ declaration.overloads.push(module.add_builtin(
+ vec![match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ }],
+ MacroCall::BitCast(Sk::Float),
+ ))
+ }
+ }
+ "pow" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = Sk::Float;
+ let ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+
+ declaration.overloads.push(
+ module
+ .add_builtin(vec![ty(), ty()], MacroCall::MathFunction(MathFunction::Pow)),
+ )
+ }
+ }
+ "abs" | "sign" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 - float/sint
+ for bits in 0..0b1000 {
+ let size = match bits & 0b11 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = match bits >> 2 {
+ 0b0 => Sk::Float,
+ _ => Sk::Sint,
+ };
+
+ let args = vec![match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ }];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ MacroCall::MathFunction(match name {
+ "abs" => MathFunction::Abs,
+ "sign" => MathFunction::Sign,
+ _ => unreachable!(),
+ }),
+ ))
+ }
+ }
+ "bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" | "findLSB"
+ | "findMSB" => {
+ let fun = match name {
+ "bitCount" => MathFunction::CountOneBits,
+ "bitfieldReverse" => MathFunction::ReverseBits,
+ "bitfieldExtract" => MathFunction::ExtractBits,
+ "bitfieldInsert" => MathFunction::InsertBits,
+ "findLSB" => MathFunction::FindLsb,
+ "findMSB" => MathFunction::FindMsb,
+ _ => unreachable!(),
+ };
+
+ let mc = match fun {
+ MathFunction::ExtractBits => MacroCall::BitfieldExtract,
+ MathFunction::InsertBits => MacroCall::BitfieldInsert,
+ _ => MacroCall::MathFunction(fun),
+ };
+
+ // bits layout
+ // bit 0 - int/uint
+ // bit 1 through 2 - dims
+ for bits in 0..0b1000 {
+ let kind = match bits & 0b1 {
+ 0b0 => Sk::Sint,
+ _ => Sk::Uint,
+ };
+ let size = match bits >> 1 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+
+ let mut args = vec![ty()];
+
+ match fun {
+ MathFunction::ExtractBits => {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ });
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ });
+ }
+ MathFunction::InsertBits => {
+ args.push(ty());
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ });
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ });
+ }
+ _ => {}
+ }
+
+ // we need to cast the return type of findLsb / findMsb
+ let mc = if kind == Sk::Uint {
+ match mc {
+ MacroCall::MathFunction(MathFunction::FindLsb) => MacroCall::FindLsbUint,
+ MacroCall::MathFunction(MathFunction::FindMsb) => MacroCall::FindMsbUint,
+ mc => mc,
+ }
+ } else {
+ mc
+ };
+
+ declaration.overloads.push(module.add_builtin(args, mc))
+ }
+ }
+ "packSnorm4x8" | "packUnorm4x8" | "packSnorm2x16" | "packUnorm2x16" | "packHalf2x16" => {
+ let fun = match name {
+ "packSnorm4x8" => MathFunction::Pack4x8snorm,
+ "packUnorm4x8" => MathFunction::Pack4x8unorm,
+ "packSnorm2x16" => MathFunction::Pack2x16unorm,
+ "packUnorm2x16" => MathFunction::Pack2x16snorm,
+ "packHalf2x16" => MathFunction::Pack2x16float,
+ _ => unreachable!(),
+ };
+
+ let ty = match fun {
+ MathFunction::Pack4x8snorm | MathFunction::Pack4x8unorm => TypeInner::Vector {
+ size: crate::VectorSize::Quad,
+ kind: Sk::Float,
+ width: 4,
+ },
+ MathFunction::Pack2x16unorm
+ | MathFunction::Pack2x16snorm
+ | MathFunction::Pack2x16float => TypeInner::Vector {
+ size: crate::VectorSize::Bi,
+ kind: Sk::Float,
+ width: 4,
+ },
+ _ => unreachable!(),
+ };
+
+ let args = vec![ty];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(fun)));
+ }
+ "unpackSnorm4x8" | "unpackUnorm4x8" | "unpackSnorm2x16" | "unpackUnorm2x16"
+ | "unpackHalf2x16" => {
+ let fun = match name {
+ "unpackSnorm4x8" => MathFunction::Unpack4x8snorm,
+ "unpackUnorm4x8" => MathFunction::Unpack4x8unorm,
+ "unpackSnorm2x16" => MathFunction::Unpack2x16snorm,
+ "unpackUnorm2x16" => MathFunction::Unpack2x16unorm,
+ "unpackHalf2x16" => MathFunction::Unpack2x16float,
+ _ => unreachable!(),
+ };
+
+ let args = vec![TypeInner::Scalar {
+ kind: Sk::Uint,
+ width: 4,
+ }];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(fun)));
+ }
+ "atan" => {
+ // bits layout
+ // bit 0 - atan/atan2
+ // bit 1 through 2 - dims
+ for bits in 0..0b1000 {
+ let fun = match bits & 0b1 {
+ 0b0 => MathFunction::Atan,
+ _ => MathFunction::Atan2,
+ };
+ let size = match bits >> 1 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = Sk::Float;
+ let ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+
+ let mut args = vec![ty()];
+
+ if fun == MathFunction::Atan2 {
+ args.push(ty())
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(fun)))
+ }
+ }
+ "all" | "any" | "not" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b11 {
+ let size = match bits {
+ 0b00 => VectorSize::Bi,
+ 0b01 => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ };
+
+ let args = vec![TypeInner::Vector {
+ size,
+ kind: Sk::Bool,
+ width: crate::BOOL_WIDTH,
+ }];
+
+ let fun = match name {
+ "all" => MacroCall::Relational(RelationalFunction::All),
+ "any" => MacroCall::Relational(RelationalFunction::Any),
+ "not" => MacroCall::Unary(UnaryOperator::Not),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" => {
+ for bits in 0..0b1001 {
+ let (size, kind) = match bits {
+ 0b0000 => (VectorSize::Bi, Sk::Float),
+ 0b0001 => (VectorSize::Tri, Sk::Float),
+ 0b0010 => (VectorSize::Quad, Sk::Float),
+ 0b0011 => (VectorSize::Bi, Sk::Sint),
+ 0b0100 => (VectorSize::Tri, Sk::Sint),
+ 0b0101 => (VectorSize::Quad, Sk::Sint),
+ 0b0110 => (VectorSize::Bi, Sk::Uint),
+ 0b0111 => (VectorSize::Tri, Sk::Uint),
+ _ => (VectorSize::Quad, Sk::Uint),
+ };
+
+ let ty = || TypeInner::Vector { size, kind, width };
+ let args = vec![ty(), ty()];
+
+ let fun = MacroCall::Binary(match name {
+ "lessThan" => BinaryOperator::Less,
+ "greaterThan" => BinaryOperator::Greater,
+ "lessThanEqual" => BinaryOperator::LessEqual,
+ "greaterThanEqual" => BinaryOperator::GreaterEqual,
+ _ => unreachable!(),
+ });
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "equal" | "notEqual" => {
+ for bits in 0..0b1100 {
+ let (size, kind) = match bits {
+ 0b0000 => (VectorSize::Bi, Sk::Float),
+ 0b0001 => (VectorSize::Tri, Sk::Float),
+ 0b0010 => (VectorSize::Quad, Sk::Float),
+ 0b0011 => (VectorSize::Bi, Sk::Sint),
+ 0b0100 => (VectorSize::Tri, Sk::Sint),
+ 0b0101 => (VectorSize::Quad, Sk::Sint),
+ 0b0110 => (VectorSize::Bi, Sk::Uint),
+ 0b0111 => (VectorSize::Tri, Sk::Uint),
+ 0b1000 => (VectorSize::Quad, Sk::Uint),
+ 0b1001 => (VectorSize::Bi, Sk::Bool),
+ 0b1010 => (VectorSize::Tri, Sk::Bool),
+ _ => (VectorSize::Quad, Sk::Bool),
+ };
+
+ let width = if let Sk::Bool = kind {
+ crate::BOOL_WIDTH
+ } else {
+ width
+ };
+
+ let ty = || TypeInner::Vector { size, kind, width };
+ let args = vec![ty(), ty()];
+
+ let fun = MacroCall::Binary(match name {
+ "equal" => BinaryOperator::Equal,
+ "notEqual" => BinaryOperator::NotEqual,
+ _ => unreachable!(),
+ });
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "min" | "max" => {
+ // bits layout
+ // bit 0 through 1 - scalar kind
+ // bit 2 through 4 - dims
+ for bits in 0..0b11100 {
+ let kind = match bits & 0b11 {
+ 0b00 => Sk::Float,
+ 0b01 => Sk::Sint,
+ 0b10 => Sk::Uint,
+ _ => continue,
+ };
+ let (size, second_size) = match bits >> 2 {
+ 0b000 => (None, None),
+ 0b001 => (Some(VectorSize::Bi), None),
+ 0b010 => (Some(VectorSize::Tri), None),
+ 0b011 => (Some(VectorSize::Quad), None),
+ 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
+ 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
+ _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
+ };
+
+ let args = vec![
+ match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ },
+ match second_size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ },
+ ];
+
+ let fun = match name {
+ "max" => MacroCall::Splatted(MathFunction::Max, size, 1),
+ "min" => MacroCall::Splatted(MathFunction::Min, size, 1),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "mix" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 through 4 - types
+ //
+ // 0b10011 is the last element since splatted single elements
+ // were already added
+ for bits in 0..0b10011 {
+ let size = match bits & 0b11 {
+ 0b00 => Some(VectorSize::Bi),
+ 0b01 => Some(VectorSize::Tri),
+ 0b10 => Some(VectorSize::Quad),
+ _ => None,
+ };
+ let (kind, splatted, boolean) = match bits >> 2 {
+ 0b000 => (Sk::Sint, false, true),
+ 0b001 => (Sk::Uint, false, true),
+ 0b010 => (Sk::Float, false, true),
+ 0b011 => (Sk::Float, false, false),
+ _ => (Sk::Float, true, false),
+ };
+
+ let ty = |kind, width| match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+ let args = vec![
+ ty(kind, width),
+ ty(kind, width),
+ match (boolean, splatted) {
+ (true, _) => ty(Sk::Bool, crate::BOOL_WIDTH),
+ (_, false) => TypeInner::Scalar { kind, width },
+ _ => ty(kind, width),
+ },
+ ];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ match boolean {
+ true => MacroCall::MixBoolean,
+ false => MacroCall::Splatted(MathFunction::Mix, size, 2),
+ },
+ ))
+ }
+ }
+ "clamp" => {
+ // bits layout
+ // bit 0 through 1 - float/int/uint
+ // bit 2 through 3 - dims
+ // bit 4 - splatted
+ //
+ // 0b11010 is the last element since splatted single elements
+ // were already added
+ for bits in 0..0b11011 {
+ let kind = match bits & 0b11 {
+ 0b00 => Sk::Float,
+ 0b01 => Sk::Sint,
+ 0b10 => Sk::Uint,
+ _ => continue,
+ };
+ let size = match (bits >> 2) & 0b11 {
+ 0b00 => Some(VectorSize::Bi),
+ 0b01 => Some(VectorSize::Tri),
+ 0b10 => Some(VectorSize::Quad),
+ _ => None,
+ };
+ let splatted = bits & 0b10000 == 0b10000;
+
+ let base_ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+ let limit_ty = || match splatted {
+ true => TypeInner::Scalar { kind, width },
+ false => base_ty(),
+ };
+
+ let args = vec![base_ty(), limit_ty(), limit_ty()];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::Clamp(size)))
+ }
+ }
+ "barrier" => declaration
+ .overloads
+ .push(module.add_builtin(Vec::new(), MacroCall::Barrier)),
+ // Add common builtins with floats
+ _ => inject_common_builtin(declaration, module, name, 4),
+ }
+}
+
+/// Injects the builtins into declaration that need doubles
+fn inject_double_builtin(declaration: &mut FunctionDeclaration, module: &mut Module, name: &str) {
+ let width = 8;
+ match name {
+ "abs" | "sign" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = Sk::Float;
+
+ let args = vec![match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ }];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ MacroCall::MathFunction(match name {
+ "abs" => MathFunction::Abs,
+ "sign" => MathFunction::Sign,
+ _ => unreachable!(),
+ }),
+ ))
+ }
+ }
+ "min" | "max" => {
+ // bits layout
+ // bit 0 through 2 - dims
+ for bits in 0..0b111 {
+ let (size, second_size) = match bits {
+ 0b000 => (None, None),
+ 0b001 => (Some(VectorSize::Bi), None),
+ 0b010 => (Some(VectorSize::Tri), None),
+ 0b011 => (Some(VectorSize::Quad), None),
+ 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
+ 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
+ _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
+ };
+ let kind = Sk::Float;
+
+ let args = vec![
+ match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ },
+ match second_size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ },
+ ];
+
+ let fun = match name {
+ "max" => MacroCall::Splatted(MathFunction::Max, size, 1),
+ "min" => MacroCall::Splatted(MathFunction::Min, size, 1),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "mix" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 through 3 - splatted/boolean
+ //
+ // 0b1010 is the last element since splatted with single elements
+ // is equal to normal single elements
+ for bits in 0..0b1011 {
+ let size = match bits & 0b11 {
+ 0b00 => Some(VectorSize::Quad),
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => None,
+ };
+ let kind = Sk::Float;
+ let (splatted, boolean) = match bits >> 2 {
+ 0b00 => (false, false),
+ 0b01 => (false, true),
+ _ => (true, false),
+ };
+
+ let ty = |kind, width| match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+ let args = vec![
+ ty(kind, width),
+ ty(kind, width),
+ match (boolean, splatted) {
+ (true, _) => ty(Sk::Bool, crate::BOOL_WIDTH),
+ (_, false) => TypeInner::Scalar { kind, width },
+ _ => ty(kind, width),
+ },
+ ];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ match boolean {
+ true => MacroCall::MixBoolean,
+ false => MacroCall::Splatted(MathFunction::Mix, size, 2),
+ },
+ ))
+ }
+ }
+ "clamp" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 - splatted
+ //
+ // 0b110 is the last element since splatted with single elements
+ // is equal to normal single elements
+ for bits in 0..0b111 {
+ let kind = Sk::Float;
+ let size = match bits & 0b11 {
+ 0b00 => Some(VectorSize::Bi),
+ 0b01 => Some(VectorSize::Tri),
+ 0b10 => Some(VectorSize::Quad),
+ _ => None,
+ };
+ let splatted = bits & 0b100 == 0b100;
+
+ let base_ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+ let limit_ty = || match splatted {
+ true => TypeInner::Scalar { kind, width },
+ false => base_ty(),
+ };
+
+ let args = vec![base_ty(), limit_ty(), limit_ty()];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::Clamp(size)))
+ }
+ }
+ "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" | "equal"
+ | "notEqual" => {
+ for bits in 0..0b11 {
+ let (size, kind) = match bits {
+ 0b00 => (VectorSize::Bi, Sk::Float),
+ 0b01 => (VectorSize::Tri, Sk::Float),
+ _ => (VectorSize::Quad, Sk::Float),
+ };
+
+ let ty = || TypeInner::Vector { size, kind, width };
+ let args = vec![ty(), ty()];
+
+ let fun = MacroCall::Binary(match name {
+ "lessThan" => BinaryOperator::Less,
+ "greaterThan" => BinaryOperator::Greater,
+ "lessThanEqual" => BinaryOperator::LessEqual,
+ "greaterThanEqual" => BinaryOperator::GreaterEqual,
+ "equal" => BinaryOperator::Equal,
+ "notEqual" => BinaryOperator::NotEqual,
+ _ => unreachable!(),
+ });
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ // Add common builtins with doubles
+ _ => inject_common_builtin(declaration, module, name, 8),
+ }
+}
+
+/// Injects the builtins into declaration that can used either float or doubles
+fn inject_common_builtin(
+ declaration: &mut FunctionDeclaration,
+ module: &mut Module,
+ name: &str,
+ float_width: crate::Bytes,
+) {
+ match name {
+ "ceil" | "round" | "roundEven" | "floor" | "fract" | "trunc" | "sqrt" | "inversesqrt"
+ | "normalize" | "length" | "isinf" | "isnan" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let args = vec![match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ }];
+
+ let fun = match name {
+ "ceil" => MacroCall::MathFunction(MathFunction::Ceil),
+ "round" | "roundEven" => MacroCall::MathFunction(MathFunction::Round),
+ "floor" => MacroCall::MathFunction(MathFunction::Floor),
+ "fract" => MacroCall::MathFunction(MathFunction::Fract),
+ "trunc" => MacroCall::MathFunction(MathFunction::Trunc),
+ "sqrt" => MacroCall::MathFunction(MathFunction::Sqrt),
+ "inversesqrt" => MacroCall::MathFunction(MathFunction::InverseSqrt),
+ "normalize" => MacroCall::MathFunction(MathFunction::Normalize),
+ "length" => MacroCall::MathFunction(MathFunction::Length),
+ "isinf" => MacroCall::Relational(RelationalFunction::IsInf),
+ "isnan" => MacroCall::Relational(RelationalFunction::IsNan),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "dot" | "reflect" | "distance" | "ldexp" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ };
+
+ let fun = match name {
+ "dot" => MacroCall::MathFunction(MathFunction::Dot),
+ "reflect" => MacroCall::MathFunction(MathFunction::Reflect),
+ "distance" => MacroCall::MathFunction(MathFunction::Distance),
+ "ldexp" => MacroCall::MathFunction(MathFunction::Ldexp),
+ _ => unreachable!(),
+ };
+
+ declaration
+ .overloads
+ .push(module.add_builtin(vec![ty(), ty()], fun))
+ }
+ }
+ "transpose" => {
+ // bits layout
+ // bit 0 through 3 - dims
+ for bits in 0..0b1001 {
+ let (rows, columns) = match bits {
+ 0b0000 => (VectorSize::Bi, VectorSize::Bi),
+ 0b0001 => (VectorSize::Bi, VectorSize::Tri),
+ 0b0010 => (VectorSize::Bi, VectorSize::Quad),
+ 0b0011 => (VectorSize::Tri, VectorSize::Bi),
+ 0b0100 => (VectorSize::Tri, VectorSize::Tri),
+ 0b0101 => (VectorSize::Tri, VectorSize::Quad),
+ 0b0110 => (VectorSize::Quad, VectorSize::Bi),
+ 0b0111 => (VectorSize::Quad, VectorSize::Tri),
+ _ => (VectorSize::Quad, VectorSize::Quad),
+ };
+
+ declaration.overloads.push(module.add_builtin(
+ vec![TypeInner::Matrix {
+ columns,
+ rows,
+ width: float_width,
+ }],
+ MacroCall::MathFunction(MathFunction::Transpose),
+ ))
+ }
+ }
+ "inverse" | "determinant" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b11 {
+ let (rows, columns) = match bits {
+ 0b00 => (VectorSize::Bi, VectorSize::Bi),
+ 0b01 => (VectorSize::Tri, VectorSize::Tri),
+ _ => (VectorSize::Quad, VectorSize::Quad),
+ };
+
+ let args = vec![TypeInner::Matrix {
+ columns,
+ rows,
+ width: float_width,
+ }];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ MacroCall::MathFunction(match name {
+ "inverse" => MathFunction::Inverse,
+ "determinant" => MathFunction::Determinant,
+ _ => unreachable!(),
+ }),
+ ))
+ }
+ }
+ "mod" | "step" => {
+ // bits layout
+ // bit 0 through 2 - dims
+ for bits in 0..0b111 {
+ let (size, second_size) = match bits {
+ 0b000 => (None, None),
+ 0b001 => (Some(VectorSize::Bi), None),
+ 0b010 => (Some(VectorSize::Tri), None),
+ 0b011 => (Some(VectorSize::Quad), None),
+ 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
+ 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
+ _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
+ };
+
+ let mut args = Vec::with_capacity(2);
+ let step = name == "step";
+
+ for i in 0..2 {
+ let maybe_size = match i == step as u32 {
+ true => size,
+ false => second_size,
+ };
+
+ args.push(match maybe_size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ })
+ }
+
+ let fun = match name {
+ "mod" => MacroCall::Mod(size),
+ "step" => MacroCall::Splatted(MathFunction::Step, size, 0),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "modf" | "frexp" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = module.types.insert(
+ Type {
+ name: None,
+ inner: match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ },
+ },
+ Span::default(),
+ );
+
+ let parameters = vec![ty, ty];
+
+ let fun = match name {
+ "modf" => MacroCall::MathFunction(MathFunction::Modf),
+ "frexp" => MacroCall::MathFunction(MathFunction::Frexp),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(Overload {
+ parameters,
+ parameters_info: vec![
+ ParameterInfo {
+ qualifier: ParameterQualifier::In,
+ depth: false,
+ },
+ ParameterInfo {
+ qualifier: ParameterQualifier::Out,
+ depth: false,
+ },
+ ],
+ kind: FunctionKind::Macro(fun),
+ defined: false,
+ internal: true,
+ void: false,
+ })
+ }
+ }
+ "cross" => {
+ let args = vec![
+ TypeInner::Vector {
+ size: VectorSize::Tri,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ TypeInner::Vector {
+ size: VectorSize::Tri,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ ];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Cross)))
+ }
+ "outerProduct" => {
+ // bits layout
+ // bit 0 through 3 - dims
+ for bits in 0..0b1001 {
+ let (size1, size2) = match bits {
+ 0b0000 => (VectorSize::Bi, VectorSize::Bi),
+ 0b0001 => (VectorSize::Bi, VectorSize::Tri),
+ 0b0010 => (VectorSize::Bi, VectorSize::Quad),
+ 0b0011 => (VectorSize::Tri, VectorSize::Bi),
+ 0b0100 => (VectorSize::Tri, VectorSize::Tri),
+ 0b0101 => (VectorSize::Tri, VectorSize::Quad),
+ 0b0110 => (VectorSize::Quad, VectorSize::Bi),
+ 0b0111 => (VectorSize::Quad, VectorSize::Tri),
+ _ => (VectorSize::Quad, VectorSize::Quad),
+ };
+
+ let args = vec![
+ TypeInner::Vector {
+ size: size1,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ TypeInner::Vector {
+ size: size2,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ ];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Outer)))
+ }
+ }
+ "faceforward" | "fma" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ };
+ let args = vec![ty(), ty(), ty()];
+
+ let fun = match name {
+ "faceforward" => MacroCall::MathFunction(MathFunction::FaceForward),
+ "fma" => MacroCall::MathFunction(MathFunction::Fma),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "refract" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ };
+ let args = vec![
+ ty(),
+ ty(),
+ TypeInner::Scalar {
+ kind: Sk::Float,
+ width: 4,
+ },
+ ];
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Refract)))
+ }
+ }
+ "smoothstep" => {
+ // bit 0 - splatted
+ // bit 1 through 2 - dims
+ for bits in 0..0b1000 {
+ let splatted = bits & 0b1 == 0b1;
+ let size = match bits >> 1 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ if splatted && size.is_none() {
+ continue;
+ }
+
+ let base_ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ };
+ let ty = || match splatted {
+ true => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ false => base_ty(),
+ };
+ declaration.overloads.push(module.add_builtin(
+ vec![ty(), ty(), base_ty()],
+ MacroCall::SmoothStep { splatted: size },
+ ))
+ }
+ }
+ // The function isn't a builtin or we don't yet support it
+ _ => {}
+ }
+}
+
+#[derive(Clone, Copy, PartialEq, Debug)]
+pub enum TextureLevelType {
+ None,
+ Lod,
+ Grad,
+}
+
+/// A compiler defined builtin function
+#[derive(Clone, Copy, PartialEq, Debug)]
+pub enum MacroCall {
+ Sampler,
+ SamplerShadow,
+ Texture {
+ proj: bool,
+ offset: bool,
+ shadow: bool,
+ level_type: TextureLevelType,
+ },
+ TextureSize {
+ arrayed: bool,
+ },
+ ImageLoad {
+ multi: bool,
+ },
+ ImageStore,
+ MathFunction(MathFunction),
+ FindLsbUint,
+ FindMsbUint,
+ BitfieldExtract,
+ BitfieldInsert,
+ Relational(RelationalFunction),
+ Unary(UnaryOperator),
+ Binary(BinaryOperator),
+ Mod(Option<VectorSize>),
+ Splatted(MathFunction, Option<VectorSize>, usize),
+ MixBoolean,
+ Clamp(Option<VectorSize>),
+ BitCast(Sk),
+ Derivate(Axis, Ctrl),
+ Barrier,
+ /// SmoothStep needs a separate variant because it might need it's inputs
+ /// to be splatted depending on the overload
+ SmoothStep {
+ /// The size of the splat operation if some
+ splatted: Option<VectorSize>,
+ },
+}
+
+impl MacroCall {
+ /// Adds the necessary expressions and statements to the passed body and
+ /// finally returns the final expression with the correct result
+ pub fn call(
+ &self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ args: &mut [Handle<Expression>],
+ meta: Span,
+ ) -> Result<Option<Handle<Expression>>> {
+ Ok(Some(match *self {
+ MacroCall::Sampler => {
+ ctx.samplers.insert(args[0], args[1]);
+ args[0]
+ }
+ MacroCall::SamplerShadow => {
+ sampled_to_depth(
+ &mut frontend.module,
+ ctx,
+ args[0],
+ meta,
+ &mut frontend.errors,
+ );
+ frontend.invalidate_expression(ctx, args[0], meta)?;
+ ctx.samplers.insert(args[0], args[1]);
+ args[0]
+ }
+ MacroCall::Texture {
+ proj,
+ offset,
+ shadow,
+ level_type,
+ } => {
+ let mut coords = args[1];
+
+ if proj {
+ let size = match *frontend.resolve_type(ctx, coords, meta)? {
+ TypeInner::Vector { size, .. } => size,
+ _ => unreachable!(),
+ };
+ let mut right = ctx.add_expression(
+ Expression::AccessIndex {
+ base: coords,
+ index: size as u32 - 1,
+ },
+ Span::default(),
+ body,
+ );
+ let left = if let VectorSize::Bi = size {
+ ctx.add_expression(
+ Expression::AccessIndex {
+ base: coords,
+ index: 0,
+ },
+ Span::default(),
+ body,
+ )
+ } else {
+ let size = match size {
+ VectorSize::Tri => VectorSize::Bi,
+ _ => VectorSize::Tri,
+ };
+ right = ctx.add_expression(
+ Expression::Splat { size, value: right },
+ Span::default(),
+ body,
+ );
+ ctx.vector_resize(size, coords, Span::default(), body)
+ };
+ coords = ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Divide,
+ left,
+ right,
+ },
+ Span::default(),
+ body,
+ );
+ }
+
+ let extra = args.get(2).copied();
+ let comps =
+ frontend.coordinate_components(ctx, args[0], coords, extra, meta, body)?;
+
+ let mut num_args = 2;
+
+ if comps.used_extra {
+ num_args += 1;
+ };
+
+ // Parse out explicit texture level.
+ let mut level = match level_type {
+ TextureLevelType::None => SampleLevel::Auto,
+
+ TextureLevelType::Lod => {
+ num_args += 1;
+
+ if shadow {
+ log::warn!("Assuming LOD {:?} is zero", args[2],);
+
+ SampleLevel::Zero
+ } else {
+ SampleLevel::Exact(args[2])
+ }
+ }
+
+ TextureLevelType::Grad => {
+ num_args += 2;
+
+ if shadow {
+ log::warn!(
+ "Assuming gradients {:?} and {:?} are not greater than 1",
+ args[2],
+ args[3],
+ );
+ SampleLevel::Zero
+ } else {
+ SampleLevel::Gradient {
+ x: args[2],
+ y: args[3],
+ }
+ }
+ }
+ };
+
+ let texture_offset = match offset {
+ true => {
+ let offset_arg = args[num_args];
+ num_args += 1;
+ match frontend.solve_constant(ctx, offset_arg, meta) {
+ Ok(v) => Some(v),
+ Err(e) => {
+ frontend.errors.push(e);
+ None
+ }
+ }
+ }
+ false => None,
+ };
+
+ // Now go back and look for optional bias arg (if available)
+ if let TextureLevelType::None = level_type {
+ level = args
+ .get(num_args)
+ .copied()
+ .map_or(SampleLevel::Auto, SampleLevel::Bias);
+ }
+
+ texture_call(ctx, args[0], level, comps, texture_offset, body, meta)?
+ }
+
+ MacroCall::TextureSize { arrayed } => {
+ let mut expr = ctx.add_expression(
+ Expression::ImageQuery {
+ image: args[0],
+ query: ImageQuery::Size {
+ level: args.get(1).copied(),
+ },
+ },
+ Span::default(),
+ body,
+ );
+
+ if arrayed {
+ let mut components = Vec::with_capacity(4);
+
+ let size = match *frontend.resolve_type(ctx, expr, meta)? {
+ TypeInner::Vector { size: ori_size, .. } => {
+ for index in 0..(ori_size as u32) {
+ components.push(ctx.add_expression(
+ Expression::AccessIndex { base: expr, index },
+ Span::default(),
+ body,
+ ))
+ }
+
+ match ori_size {
+ VectorSize::Bi => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ }
+ }
+ _ => {
+ components.push(expr);
+ VectorSize::Bi
+ }
+ };
+
+ components.push(ctx.add_expression(
+ Expression::ImageQuery {
+ image: args[0],
+ query: ImageQuery::NumLayers,
+ },
+ Span::default(),
+ body,
+ ));
+
+ let ty = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ Span::default(),
+ );
+
+ expr = ctx.add_expression(Expression::Compose { components, ty }, meta, body)
+ }
+
+ ctx.add_expression(
+ Expression::As {
+ expr,
+ kind: Sk::Sint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::ImageLoad { multi } => {
+ let comps =
+ frontend.coordinate_components(ctx, args[0], args[1], None, meta, body)?;
+ let (sample, level) = match (multi, args.get(2)) {
+ (_, None) => (None, None),
+ (true, Some(&arg)) => (Some(arg), None),
+ (false, Some(&arg)) => (None, Some(arg)),
+ };
+ ctx.add_expression(
+ Expression::ImageLoad {
+ image: args[0],
+ coordinate: comps.coordinate,
+ array_index: comps.array_index,
+ sample,
+ level,
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::ImageStore => {
+ let comps =
+ frontend.coordinate_components(ctx, args[0], args[1], None, meta, body)?;
+ ctx.emit_restart(body);
+ body.push(
+ crate::Statement::ImageStore {
+ image: args[0],
+ coordinate: comps.coordinate,
+ array_index: comps.array_index,
+ value: args[2],
+ },
+ meta,
+ );
+ return Ok(None);
+ }
+ MacroCall::MathFunction(fun) => ctx.add_expression(
+ Expression::Math {
+ fun,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: args.get(3).copied(),
+ },
+ Span::default(),
+ body,
+ ),
+ mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => {
+ let fun = match mc {
+ MacroCall::FindLsbUint => MathFunction::FindLsb,
+ MacroCall::FindMsbUint => MathFunction::FindMsb,
+ _ => unreachable!(),
+ };
+ let res = ctx.add_expression(
+ Expression::Math {
+ fun,
+ arg: args[0],
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ Span::default(),
+ body,
+ );
+ ctx.add_expression(
+ Expression::As {
+ expr: res,
+ kind: Sk::Sint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::BitfieldInsert => {
+ let conv_arg_2 = ctx.add_expression(
+ Expression::As {
+ expr: args[2],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ );
+ let conv_arg_3 = ctx.add_expression(
+ Expression::As {
+ expr: args[3],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ );
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::InsertBits,
+ arg: args[0],
+ arg1: Some(args[1]),
+ arg2: Some(conv_arg_2),
+ arg3: Some(conv_arg_3),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::BitfieldExtract => {
+ let conv_arg_1 = ctx.add_expression(
+ Expression::As {
+ expr: args[1],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ );
+ let conv_arg_2 = ctx.add_expression(
+ Expression::As {
+ expr: args[2],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ );
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::ExtractBits,
+ arg: args[0],
+ arg1: Some(conv_arg_1),
+ arg2: Some(conv_arg_2),
+ arg3: None,
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::Relational(fun) => ctx.add_expression(
+ Expression::Relational {
+ fun,
+ argument: args[0],
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Unary(op) => ctx.add_expression(
+ Expression::Unary { op, expr: args[0] },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Binary(op) => ctx.add_expression(
+ Expression::Binary {
+ op,
+ left: args[0],
+ right: args[1],
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Mod(size) => {
+ ctx.implicit_splat(frontend, &mut args[1], meta, size)?;
+
+ // x - y * floor(x / y)
+
+ let div = ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Divide,
+ left: args[0],
+ right: args[1],
+ },
+ Span::default(),
+ body,
+ );
+ let floor = ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::Floor,
+ arg: div,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ Span::default(),
+ body,
+ );
+ let mult = ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Multiply,
+ left: floor,
+ right: args[1],
+ },
+ Span::default(),
+ body,
+ );
+ ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Subtract,
+ left: args[0],
+ right: mult,
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::Splatted(fun, size, i) => {
+ ctx.implicit_splat(frontend, &mut args[i], meta, size)?;
+
+ ctx.add_expression(
+ Expression::Math {
+ fun,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: args.get(3).copied(),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::MixBoolean => ctx.add_expression(
+ Expression::Select {
+ condition: args[2],
+ accept: args[1],
+ reject: args[0],
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Clamp(size) => {
+ ctx.implicit_splat(frontend, &mut args[1], meta, size)?;
+ ctx.implicit_splat(frontend, &mut args[2], meta, size)?;
+
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::Clamp,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: args.get(3).copied(),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::BitCast(kind) => ctx.add_expression(
+ Expression::As {
+ expr: args[0],
+ kind,
+ convert: None,
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Derivate(axis, ctrl) => ctx.add_expression(
+ Expression::Derivative {
+ axis,
+ ctrl,
+ expr: args[0],
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Barrier => {
+ ctx.emit_restart(body);
+ body.push(crate::Statement::Barrier(crate::Barrier::all()), meta);
+ return Ok(None);
+ }
+ MacroCall::SmoothStep { splatted } => {
+ ctx.implicit_splat(frontend, &mut args[0], meta, splatted)?;
+ ctx.implicit_splat(frontend, &mut args[1], meta, splatted)?;
+
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::SmoothStep,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: None,
+ },
+ Span::default(),
+ body,
+ )
+ }
+ }))
+ }
+}
+
+fn texture_call(
+ ctx: &mut Context,
+ image: Handle<Expression>,
+ level: SampleLevel,
+ comps: CoordComponents,
+ offset: Option<Handle<Constant>>,
+ body: &mut Block,
+ meta: Span,
+) -> Result<Handle<Expression>> {
+ if let Some(sampler) = ctx.samplers.get(&image).copied() {
+ let mut array_index = comps.array_index;
+
+ if let Some(ref mut array_index_expr) = array_index {
+ ctx.conversion(array_index_expr, meta, Sk::Sint, 4)?;
+ }
+
+ Ok(ctx.add_expression(
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: None, //TODO
+ coordinate: comps.coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref: comps.depth_ref,
+ },
+ meta,
+ body,
+ ))
+ } else {
+ Err(Error {
+ kind: ErrorKind::SemanticError("Bad call".into()),
+ meta,
+ })
+ }
+}
+
+/// Helper struct for texture calls with the separate components from the vector argument
+///
+/// Obtained by calling [`coordinate_components`](Frontend::coordinate_components)
+#[derive(Debug)]
+struct CoordComponents {
+ coordinate: Handle<Expression>,
+ depth_ref: Option<Handle<Expression>>,
+ array_index: Option<Handle<Expression>>,
+ used_extra: bool,
+}
+
+impl Frontend {
+ /// Helper function for texture calls, splits the vector argument into it's components
+ fn coordinate_components(
+ &mut self,
+ ctx: &mut Context,
+ image: Handle<Expression>,
+ coord: Handle<Expression>,
+ extra: Option<Handle<Expression>>,
+ meta: Span,
+ body: &mut Block,
+ ) -> Result<CoordComponents> {
+ if let TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } = *self.resolve_type(ctx, image, meta)?
+ {
+ let image_size = match dim {
+ Dim::D1 => None,
+ Dim::D2 => Some(VectorSize::Bi),
+ Dim::D3 => Some(VectorSize::Tri),
+ Dim::Cube => Some(VectorSize::Tri),
+ };
+ let coord_size = match *self.resolve_type(ctx, coord, meta)? {
+ TypeInner::Vector { size, .. } => Some(size),
+ _ => None,
+ };
+ let (shadow, storage) = match class {
+ ImageClass::Depth { .. } => (true, false),
+ ImageClass::Storage { .. } => (false, true),
+ ImageClass::Sampled { .. } => (false, false),
+ };
+
+ let coordinate = match (image_size, coord_size) {
+ (Some(size), Some(coord_s)) if size != coord_s => {
+ ctx.vector_resize(size, coord, Span::default(), body)
+ }
+ (None, Some(_)) => ctx.add_expression(
+ Expression::AccessIndex {
+ base: coord,
+ index: 0,
+ },
+ Span::default(),
+ body,
+ ),
+ _ => coord,
+ };
+
+ let mut coord_index = image_size.map_or(1, |s| s as u32);
+
+ let array_index = if arrayed && !(storage && dim == Dim::Cube) {
+ let index = coord_index;
+ coord_index += 1;
+
+ Some(ctx.add_expression(
+ Expression::AccessIndex { base: coord, index },
+ Span::default(),
+ body,
+ ))
+ } else {
+ None
+ };
+ let mut used_extra = false;
+ let depth_ref = match shadow {
+ true => {
+ let index = coord_index;
+
+ if index == 4 {
+ used_extra = true;
+ extra
+ } else {
+ Some(ctx.add_expression(
+ Expression::AccessIndex { base: coord, index },
+ Span::default(),
+ body,
+ ))
+ }
+ }
+ false => None,
+ };
+
+ Ok(CoordComponents {
+ coordinate,
+ depth_ref,
+ array_index,
+ used_extra,
+ })
+ } else {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Type is not an image".into()),
+ meta,
+ });
+
+ Ok(CoordComponents {
+ coordinate: coord,
+ depth_ref: None,
+ array_index: None,
+ used_extra: false,
+ })
+ }
+ }
+}
+
+/// Helper function to cast a expression holding a sampled image to a
+/// depth image.
+pub fn sampled_to_depth(
+ module: &mut Module,
+ ctx: &mut Context,
+ image: Handle<Expression>,
+ meta: Span,
+ errors: &mut Vec<Error>,
+) {
+ // Get the a mutable type handle of the underlying image storage
+ let ty = match ctx[image] {
+ Expression::GlobalVariable(handle) => &mut module.global_variables.get_mut(handle).ty,
+ Expression::FunctionArgument(i) => {
+ // Mark the function argument as carrying a depth texture
+ ctx.parameters_info[i as usize].depth = true;
+ // NOTE: We need to later also change the parameter type
+ &mut ctx.arguments[i as usize].ty
+ }
+ _ => {
+ // Only globals and function arguments are allowed to carry an image
+ return errors.push(Error {
+ kind: ErrorKind::SemanticError("Not a valid texture expression".into()),
+ meta,
+ });
+ }
+ };
+
+ match module.types[*ty].inner {
+ // Update the image class to depth in case it already isn't
+ TypeInner::Image {
+ class,
+ dim,
+ arrayed,
+ } => match class {
+ ImageClass::Sampled { multi, .. } => {
+ *ty = module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Depth { multi },
+ },
+ },
+ Span::default(),
+ )
+ }
+ ImageClass::Depth { .. } => {}
+ // Other image classes aren't allowed to be transformed to depth
+ ImageClass::Storage { .. } => errors.push(Error {
+ kind: ErrorKind::SemanticError("Not a texture".into()),
+ meta,
+ }),
+ },
+ _ => errors.push(Error {
+ kind: ErrorKind::SemanticError("Not a texture".into()),
+ meta,
+ }),
+ };
+
+ // Copy the handle to allow borrowing the `ctx` again
+ let ty = *ty;
+
+ // If the image was passed through a function argument we also need to change
+ // the corresponding parameter
+ if let Expression::FunctionArgument(i) = ctx[image] {
+ ctx.parameters[i as usize] = ty;
+ }
+}
+
+bitflags::bitflags! {
+ /// Influences the operation `texture_args_generator`
+ struct TextureArgsOptions: u32 {
+ /// Generates multisampled variants of images
+ const MULTI = 1 << 0;
+ /// Generates shadow variants of images
+ const SHADOW = 1 << 1;
+ /// Generates standard images
+ const STANDARD = 1 << 2;
+ /// Generates cube arrayed images
+ const CUBE_ARRAY = 1 << 3;
+ /// Generates cube arrayed images
+ const D2_MULTI_ARRAY = 1 << 4;
+ }
+}
+
+impl From<BuiltinVariations> for TextureArgsOptions {
+ fn from(variations: BuiltinVariations) -> Self {
+ let mut options = TextureArgsOptions::empty();
+ if variations.contains(BuiltinVariations::STANDARD) {
+ options |= TextureArgsOptions::STANDARD
+ }
+ if variations.contains(BuiltinVariations::CUBE_TEXTURES_ARRAY) {
+ options |= TextureArgsOptions::CUBE_ARRAY
+ }
+ if variations.contains(BuiltinVariations::D2_MULTI_TEXTURES_ARRAY) {
+ options |= TextureArgsOptions::D2_MULTI_ARRAY
+ }
+ options
+ }
+}
+
+/// Helper function to generate the image components for texture/image builtins
+///
+/// Calls the passed function `f` with:
+/// ```text
+/// f(ScalarKind, ImageDimension, arrayed, multi, shadow)
+/// ```
+///
+/// `options` controls extra image variants generation like multisampling and depth,
+/// see the struct documentation
+fn texture_args_generator(
+ options: TextureArgsOptions,
+ mut f: impl FnMut(crate::ScalarKind, Dim, bool, bool, bool),
+) {
+ for kind in [Sk::Float, Sk::Uint, Sk::Sint].iter().copied() {
+ for dim in [Dim::D1, Dim::D2, Dim::D3, Dim::Cube].iter().copied() {
+ for arrayed in [false, true].iter().copied() {
+ if dim == Dim::Cube && arrayed {
+ if !options.contains(TextureArgsOptions::CUBE_ARRAY) {
+ continue;
+ }
+ } else if Dim::D2 == dim
+ && options.contains(TextureArgsOptions::MULTI)
+ && arrayed
+ && options.contains(TextureArgsOptions::D2_MULTI_ARRAY)
+ {
+ // multisampling for sampler2DMSArray
+ f(kind, dim, arrayed, true, false);
+ } else if !options.contains(TextureArgsOptions::STANDARD) {
+ continue;
+ }
+
+ f(kind, dim, arrayed, false, false);
+
+ // 3D images can't be neither arrayed nor shadow
+ // so we break out early, this way arrayed will always
+ // be false and we won't hit the shadow branch
+ if let Dim::D3 = dim {
+ break;
+ }
+
+ if Dim::D2 == dim && options.contains(TextureArgsOptions::MULTI) && !arrayed {
+ // multisampling
+ f(kind, dim, arrayed, true, false);
+ }
+
+ if Sk::Float == kind && options.contains(TextureArgsOptions::SHADOW) {
+ // shadow
+ f(kind, dim, arrayed, false, true);
+ }
+ }
+ }
+ }
+}
+
+/// Helper functions used to convert from a image dimension into a integer representing the
+/// number of components needed for the coordinates vector (1 means scalar instead of vector)
+const fn image_dims_to_coords_size(dim: Dim) -> usize {
+ match dim {
+ Dim::D1 => 1,
+ Dim::D2 => 2,
+ _ => 3,
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/constants.rs b/third_party/rust/naga/src/front/glsl/constants.rs
new file mode 100644
index 0000000000..045a9c6ffb
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/constants.rs
@@ -0,0 +1,984 @@
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ BinaryOperator, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
+ UnaryOperator,
+};
+
+#[derive(Debug)]
+pub struct ConstantSolver<'a> {
+ pub types: &'a mut UniqueArena<Type>,
+ pub expressions: &'a Arena<Expression>,
+ pub constants: &'a mut Arena<Constant>,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum ConstantSolvingError {
+ #[error("Constants cannot access function arguments")]
+ FunctionArg,
+ #[error("Constants cannot access global variables")]
+ GlobalVariable,
+ #[error("Constants cannot access local variables")]
+ LocalVariable,
+ #[error("Cannot get the array length of a non array type")]
+ InvalidArrayLengthArg,
+ #[error("Constants cannot get the array length of a dynamically sized array")]
+ ArrayLengthDynamic,
+ #[error("Constants cannot call functions")]
+ Call,
+ #[error("Constants don't support atomic functions")]
+ Atomic,
+ #[error("Constants don't support relational functions")]
+ Relational,
+ #[error("Constants don't support derivative functions")]
+ Derivative,
+ #[error("Constants don't support select expressions")]
+ Select,
+ #[error("Constants don't support load expressions")]
+ Load,
+ #[error("Constants don't support image expressions")]
+ ImageExpression,
+ #[error("Constants don't support ray query expressions")]
+ RayQueryExpression,
+ #[error("Cannot access the type")]
+ InvalidAccessBase,
+ #[error("Cannot access at the index")]
+ InvalidAccessIndex,
+ #[error("Cannot access with index of type")]
+ InvalidAccessIndexTy,
+ #[error("Constants don't support bitcasts")]
+ Bitcast,
+ #[error("Cannot cast type")]
+ InvalidCastArg,
+ #[error("Cannot apply the unary op to the argument")]
+ InvalidUnaryOpArg,
+ #[error("Cannot apply the binary op to the arguments")]
+ InvalidBinaryOpArgs,
+ #[error("Cannot apply math function to type")]
+ InvalidMathArg,
+ #[error("Splat is defined only on scalar values")]
+ SplatScalarOnly,
+ #[error("Can only swizzle vector constants")]
+ SwizzleVectorOnly,
+ #[error("Not implemented as constant expression: {0}")]
+ NotImplemented(String),
+}
+
+impl<'a> ConstantSolver<'a> {
+ pub fn solve(
+ &mut self,
+ expr: Handle<Expression>,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let span = self.expressions.get_span(expr);
+ match self.expressions[expr] {
+ Expression::Constant(constant) => Ok(constant),
+ Expression::AccessIndex { base, index } => self.access(base, index as usize),
+ Expression::Access { base, index } => {
+ let index = self.solve(index)?;
+
+ self.access(base, self.constant_index(index)?)
+ }
+ Expression::Splat {
+ size,
+ value: splat_value,
+ } => {
+ let value_constant = self.solve(splat_value)?;
+ let ty = match self.constants[value_constant].inner {
+ ConstantInner::Scalar { ref value, width } => {
+ let kind = value.scalar_kind();
+ self.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector { size, kind, width },
+ },
+ span,
+ )
+ }
+ ConstantInner::Composite { .. } => {
+ return Err(ConstantSolvingError::SplatScalarOnly);
+ }
+ };
+
+ let inner = ConstantInner::Composite {
+ ty,
+ components: vec![value_constant; size as usize],
+ };
+ Ok(self.register_constant(inner, span))
+ }
+ Expression::Swizzle {
+ size,
+ vector: src_vector,
+ pattern,
+ } => {
+ let src_constant = self.solve(src_vector)?;
+ let (ty, src_components) = match self.constants[src_constant].inner {
+ ConstantInner::Scalar { .. } => {
+ return Err(ConstantSolvingError::SwizzleVectorOnly);
+ }
+ ConstantInner::Composite {
+ ty,
+ components: ref src_components,
+ } => match self.types[ty].inner {
+ crate::TypeInner::Vector {
+ size: _,
+ kind,
+ width,
+ } => {
+ let dst_ty = self.types.insert(
+ Type {
+ name: None,
+ inner: crate::TypeInner::Vector { size, kind, width },
+ },
+ span,
+ );
+ (dst_ty, &src_components[..])
+ }
+ _ => {
+ return Err(ConstantSolvingError::SwizzleVectorOnly);
+ }
+ },
+ };
+
+ let components = pattern
+ .iter()
+ .map(|&sc| src_components[sc as usize])
+ .collect();
+ let inner = ConstantInner::Composite { ty, components };
+
+ Ok(self.register_constant(inner, span))
+ }
+ Expression::Compose { ty, ref components } => {
+ let components = components
+ .iter()
+ .map(|c| self.solve(*c))
+ .collect::<Result<_, _>>()?;
+ let inner = ConstantInner::Composite { ty, components };
+
+ Ok(self.register_constant(inner, span))
+ }
+ Expression::Unary { expr, op } => {
+ let expr_constant = self.solve(expr)?;
+
+ self.unary_op(op, expr_constant, span)
+ }
+ Expression::Binary { left, right, op } => {
+ let left_constant = self.solve(left)?;
+ let right_constant = self.solve(right)?;
+
+ self.binary_op(op, left_constant, right_constant, span)
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ ..
+ } => {
+ let arg = self.solve(arg)?;
+ let arg1 = arg1.map(|arg| self.solve(arg)).transpose()?;
+ let arg2 = arg2.map(|arg| self.solve(arg)).transpose()?;
+
+ let const0 = &self.constants[arg].inner;
+ let const1 = arg1.map(|arg| &self.constants[arg].inner);
+ let const2 = arg2.map(|arg| &self.constants[arg].inner);
+
+ match fun {
+ crate::MathFunction::Pow => {
+ let (value, width) = match (const0, const1.unwrap()) {
+ (
+ &ConstantInner::Scalar {
+ width,
+ value: value0,
+ },
+ &ConstantInner::Scalar { value: value1, .. },
+ ) => (
+ match (value0, value1) {
+ (ScalarValue::Sint(a), ScalarValue::Sint(b)) => {
+ ScalarValue::Sint(a.pow(b as u32))
+ }
+ (ScalarValue::Uint(a), ScalarValue::Uint(b)) => {
+ ScalarValue::Uint(a.pow(b as u32))
+ }
+ (ScalarValue::Float(a), ScalarValue::Float(b)) => {
+ ScalarValue::Float(a.powf(b))
+ }
+ _ => return Err(ConstantSolvingError::InvalidMathArg),
+ },
+ width,
+ ),
+ _ => return Err(ConstantSolvingError::InvalidMathArg),
+ };
+
+ let inner = ConstantInner::Scalar { width, value };
+ Ok(self.register_constant(inner, span))
+ }
+ crate::MathFunction::Clamp => {
+ let (value, width) = match (const0, const1.unwrap(), const2.unwrap()) {
+ (
+ &ConstantInner::Scalar {
+ width,
+ value: value0,
+ },
+ &ConstantInner::Scalar { value: value1, .. },
+ &ConstantInner::Scalar { value: value2, .. },
+ ) => (
+ match (value0, value1, value2) {
+ (
+ ScalarValue::Sint(a),
+ ScalarValue::Sint(b),
+ ScalarValue::Sint(c),
+ ) => ScalarValue::Sint(a.clamp(b, c)),
+ (
+ ScalarValue::Uint(a),
+ ScalarValue::Uint(b),
+ ScalarValue::Uint(c),
+ ) => ScalarValue::Uint(a.clamp(b, c)),
+ (
+ ScalarValue::Float(a),
+ ScalarValue::Float(b),
+ ScalarValue::Float(c),
+ ) => ScalarValue::Float(glsl_float_clamp(a, b, c)),
+ _ => return Err(ConstantSolvingError::InvalidMathArg),
+ },
+ width,
+ ),
+ _ => {
+ return Err(ConstantSolvingError::NotImplemented(format!(
+ "{fun:?} applied to vector values"
+ )))
+ }
+ };
+
+ let inner = ConstantInner::Scalar { width, value };
+ Ok(self.register_constant(inner, span))
+ }
+ _ => Err(ConstantSolvingError::NotImplemented(format!("{fun:?}"))),
+ }
+ }
+ Expression::As {
+ convert,
+ expr,
+ kind,
+ } => {
+ let expr_constant = self.solve(expr)?;
+
+ match convert {
+ Some(width) => self.cast(expr_constant, kind, width, span),
+ None => Err(ConstantSolvingError::Bitcast),
+ }
+ }
+ Expression::ArrayLength(expr) => {
+ let array = self.solve(expr)?;
+
+ match self.constants[array].inner {
+ ConstantInner::Scalar { .. } => {
+ Err(ConstantSolvingError::InvalidArrayLengthArg)
+ }
+ ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
+ TypeInner::Array { size, .. } => match size {
+ crate::ArraySize::Constant(constant) => Ok(constant),
+ crate::ArraySize::Dynamic => {
+ Err(ConstantSolvingError::ArrayLengthDynamic)
+ }
+ },
+ _ => Err(ConstantSolvingError::InvalidArrayLengthArg),
+ },
+ }
+ }
+
+ Expression::Load { .. } => Err(ConstantSolvingError::Load),
+ Expression::Select { .. } => Err(ConstantSolvingError::Select),
+ Expression::LocalVariable(_) => Err(ConstantSolvingError::LocalVariable),
+ Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative),
+ Expression::Relational { .. } => Err(ConstantSolvingError::Relational),
+ Expression::CallResult { .. } => Err(ConstantSolvingError::Call),
+ Expression::AtomicResult { .. } => Err(ConstantSolvingError::Atomic),
+ Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg),
+ Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable),
+ Expression::ImageSample { .. }
+ | Expression::ImageLoad { .. }
+ | Expression::ImageQuery { .. } => Err(ConstantSolvingError::ImageExpression),
+ Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
+ Err(ConstantSolvingError::RayQueryExpression)
+ }
+ }
+ }
+
+ fn access(
+ &mut self,
+ base: Handle<Expression>,
+ index: usize,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let base = self.solve(base)?;
+
+ match self.constants[base].inner {
+ ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
+ ConstantInner::Composite { ty, ref components } => {
+ match self.types[ty].inner {
+ TypeInner::Vector { .. }
+ | TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::Struct { .. } => (),
+ _ => return Err(ConstantSolvingError::InvalidAccessBase),
+ }
+
+ components
+ .get(index)
+ .copied()
+ .ok_or(ConstantSolvingError::InvalidAccessIndex)
+ }
+ }
+ }
+
+ fn constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError> {
+ match self.constants[constant].inner {
+ ConstantInner::Scalar {
+ value: ScalarValue::Uint(index),
+ ..
+ } => Ok(index as usize),
+ _ => Err(ConstantSolvingError::InvalidAccessIndexTy),
+ }
+ }
+
+ fn cast(
+ &mut self,
+ constant: Handle<Constant>,
+ kind: ScalarKind,
+ target_width: crate::Bytes,
+ span: crate::Span,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let mut inner = self.constants[constant].inner.clone();
+
+ match inner {
+ ConstantInner::Scalar {
+ ref mut value,
+ ref mut width,
+ } => {
+ *width = target_width;
+ *value = match kind {
+ ScalarKind::Sint => ScalarValue::Sint(match *value {
+ ScalarValue::Sint(v) => v,
+ ScalarValue::Uint(v) => v as i64,
+ ScalarValue::Float(v) => v as i64,
+ ScalarValue::Bool(v) => v as i64,
+ }),
+ ScalarKind::Uint => ScalarValue::Uint(match *value {
+ ScalarValue::Sint(v) => v as u64,
+ ScalarValue::Uint(v) => v,
+ ScalarValue::Float(v) => v as u64,
+ ScalarValue::Bool(v) => v as u64,
+ }),
+ ScalarKind::Float => ScalarValue::Float(match *value {
+ ScalarValue::Sint(v) => v as f64,
+ ScalarValue::Uint(v) => v as f64,
+ ScalarValue::Float(v) => v,
+ ScalarValue::Bool(v) => v as u64 as f64,
+ }),
+ ScalarKind::Bool => ScalarValue::Bool(match *value {
+ ScalarValue::Sint(v) => v != 0,
+ ScalarValue::Uint(v) => v != 0,
+ ScalarValue::Float(v) => v != 0.0,
+ ScalarValue::Bool(v) => v,
+ }),
+ }
+ }
+ ConstantInner::Composite {
+ ty,
+ ref mut components,
+ } => {
+ match self.types[ty].inner {
+ TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
+ _ => return Err(ConstantSolvingError::InvalidCastArg),
+ }
+
+ for component in components {
+ *component = self.cast(*component, kind, target_width, span)?;
+ }
+ }
+ }
+
+ Ok(self.register_constant(inner, span))
+ }
+
+ fn unary_op(
+ &mut self,
+ op: UnaryOperator,
+ constant: Handle<Constant>,
+ span: crate::Span,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let mut inner = self.constants[constant].inner.clone();
+
+ match inner {
+ ConstantInner::Scalar { ref mut value, .. } => match op {
+ UnaryOperator::Negate => match *value {
+ ScalarValue::Sint(ref mut v) => *v = -*v,
+ ScalarValue::Float(ref mut v) => *v = -*v,
+ _ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
+ },
+ UnaryOperator::Not => match *value {
+ ScalarValue::Sint(ref mut v) => *v = !*v,
+ ScalarValue::Uint(ref mut v) => *v = !*v,
+ ScalarValue::Bool(ref mut v) => *v = !*v,
+ ScalarValue::Float(_) => return Err(ConstantSolvingError::InvalidUnaryOpArg),
+ },
+ },
+ ConstantInner::Composite {
+ ty,
+ ref mut components,
+ } => {
+ match self.types[ty].inner {
+ TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
+ _ => return Err(ConstantSolvingError::InvalidCastArg),
+ }
+
+ for component in components {
+ *component = self.unary_op(op, *component, span)?
+ }
+ }
+ }
+
+ Ok(self.register_constant(inner, span))
+ }
+
+ fn binary_op(
+ &mut self,
+ op: BinaryOperator,
+ left: Handle<Constant>,
+ right: Handle<Constant>,
+ span: crate::Span,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let left_inner = &self.constants[left].inner;
+ let right_inner = &self.constants[right].inner;
+
+ let inner = match (left_inner, right_inner) {
+ (
+ &ConstantInner::Scalar {
+ value: left_value,
+ width,
+ },
+ &ConstantInner::Scalar {
+ value: right_value,
+ width: _,
+ },
+ ) => {
+ let value = match op {
+ BinaryOperator::Equal => ScalarValue::Bool(left_value == right_value),
+ BinaryOperator::NotEqual => ScalarValue::Bool(left_value != right_value),
+ BinaryOperator::Less => ScalarValue::Bool(left_value < right_value),
+ BinaryOperator::LessEqual => ScalarValue::Bool(left_value <= right_value),
+ BinaryOperator::Greater => ScalarValue::Bool(left_value > right_value),
+ BinaryOperator::GreaterEqual => ScalarValue::Bool(left_value >= right_value),
+
+ _ => match (left_value, right_value) {
+ (ScalarValue::Sint(a), ScalarValue::Sint(b)) => {
+ ScalarValue::Sint(match op {
+ BinaryOperator::Add => a.wrapping_add(b),
+ BinaryOperator::Subtract => a.wrapping_sub(b),
+ BinaryOperator::Multiply => a.wrapping_mul(b),
+ BinaryOperator::Divide => a.checked_div(b).unwrap_or(0),
+ BinaryOperator::Modulo => a.checked_rem(b).unwrap_or(0),
+ BinaryOperator::And => a & b,
+ BinaryOperator::ExclusiveOr => a ^ b,
+ BinaryOperator::InclusiveOr => a | b,
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ (ScalarValue::Sint(a), ScalarValue::Uint(b)) => {
+ ScalarValue::Sint(match op {
+ BinaryOperator::ShiftLeft => a.wrapping_shl(b as u32),
+ BinaryOperator::ShiftRight => a.wrapping_shr(b as u32),
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ (ScalarValue::Uint(a), ScalarValue::Uint(b)) => {
+ ScalarValue::Uint(match op {
+ BinaryOperator::Add => a.wrapping_add(b),
+ BinaryOperator::Subtract => a.wrapping_sub(b),
+ BinaryOperator::Multiply => a.wrapping_mul(b),
+ BinaryOperator::Divide => a.checked_div(b).unwrap_or(0),
+ BinaryOperator::Modulo => a.checked_rem(b).unwrap_or(0),
+ BinaryOperator::And => a & b,
+ BinaryOperator::ExclusiveOr => a ^ b,
+ BinaryOperator::InclusiveOr => a | b,
+ BinaryOperator::ShiftLeft => a.wrapping_shl(b as u32),
+ BinaryOperator::ShiftRight => a.wrapping_shr(b as u32),
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ (ScalarValue::Float(a), ScalarValue::Float(b)) => {
+ ScalarValue::Float(match op {
+ BinaryOperator::Add => a + b,
+ BinaryOperator::Subtract => a - b,
+ BinaryOperator::Multiply => a * b,
+ BinaryOperator::Divide => a / b,
+ BinaryOperator::Modulo => a - b * (a / b).floor(),
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ (ScalarValue::Bool(a), ScalarValue::Bool(b)) => {
+ ScalarValue::Bool(match op {
+ BinaryOperator::LogicalAnd => a && b,
+ BinaryOperator::LogicalOr => a || b,
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ },
+ };
+
+ ConstantInner::Scalar { value, width }
+ }
+ (&ConstantInner::Composite { ref components, ty }, &ConstantInner::Scalar { .. }) => {
+ let mut components = components.clone();
+ for comp in components.iter_mut() {
+ *comp = self.binary_op(op, *comp, right, span)?;
+ }
+ ConstantInner::Composite { ty, components }
+ }
+ (&ConstantInner::Scalar { .. }, &ConstantInner::Composite { ref components, ty }) => {
+ let mut components = components.clone();
+ for comp in components.iter_mut() {
+ *comp = self.binary_op(op, left, *comp, span)?;
+ }
+ ConstantInner::Composite { ty, components }
+ }
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ };
+
+ Ok(self.register_constant(inner, span))
+ }
+
+ fn register_constant(&mut self, inner: ConstantInner, span: crate::Span) -> Handle<Constant> {
+ self.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ span,
+ )
+ }
+}
+
+/// Helper function to implement the GLSL `max` function for floats.
+///
+/// While Rust does provide a `f64::max` method, it has a different behavior than the
+/// GLSL `max` for NaNs. In Rust, if any of the arguments is a NaN, then the other
+/// is returned.
+///
+/// This leads to different results in the following example
+/// ```
+/// use std::cmp::max;
+/// std::f64::NAN.max(1.0);
+/// ```
+///
+/// Rust will return `1.0` while GLSL should return NaN.
+fn glsl_float_max(x: f64, y: f64) -> f64 {
+ if x < y {
+ y
+ } else {
+ x
+ }
+}
+
+/// Helper function to implement the GLSL `min` function for floats.
+///
+/// While Rust does provide a `f64::min` method, it has a different behavior than the
+/// GLSL `min` for NaNs. In Rust, if any of the arguments is a NaN, then the other
+/// is returned.
+///
+/// This leads to different results in the following example
+/// ```
+/// use std::cmp::min;
+/// std::f64::NAN.min(1.0);
+/// ```
+///
+/// Rust will return `1.0` while GLSL should return NaN.
+fn glsl_float_min(x: f64, y: f64) -> f64 {
+ if y < x {
+ y
+ } else {
+ x
+ }
+}
+
+/// Helper function to implement the GLSL `clamp` function for floats.
+///
+/// While Rust does provide a `f64::clamp` method, it panics if either
+/// `min` or `max` are `NaN`s which is not the behavior specified by
+/// the glsl specification.
+fn glsl_float_clamp(value: f64, min: f64, max: f64) -> f64 {
+ glsl_float_min(glsl_float_max(value, min), max)
+}
+
+#[cfg(test)]
+mod tests {
+ use std::vec;
+
+ use crate::{
+ Arena, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
+ UnaryOperator, UniqueArena, VectorSize,
+ };
+
+ use super::ConstantSolver;
+
+ #[test]
+ fn nan_handling() {
+ assert!(super::glsl_float_max(f64::NAN, 2.0).is_nan());
+ assert!(!super::glsl_float_max(2.0, f64::NAN).is_nan());
+
+ assert!(super::glsl_float_min(f64::NAN, 2.0).is_nan());
+ assert!(!super::glsl_float_min(2.0, f64::NAN).is_nan());
+
+ assert!(super::glsl_float_clamp(f64::NAN, 1.0, 2.0).is_nan());
+ assert!(!super::glsl_float_clamp(1.0, f64::NAN, 2.0).is_nan());
+ assert!(!super::glsl_float_clamp(1.0, 2.0, f64::NAN).is_nan());
+ }
+
+ #[test]
+ fn unary_op() {
+ let mut types = UniqueArena::new();
+ let mut expressions = Arena::new();
+ let mut constants = Arena::new();
+
+ let vec_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: VectorSize::Bi,
+ kind: ScalarKind::Sint,
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(4),
+ },
+ },
+ Default::default(),
+ );
+
+ let h1 = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(8),
+ },
+ },
+ Default::default(),
+ );
+
+ let vec_h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: vec_ty,
+ components: vec![h, h1],
+ },
+ },
+ Default::default(),
+ );
+
+ let expr = expressions.append(Expression::Constant(h), Default::default());
+ let expr1 = expressions.append(Expression::Constant(vec_h), Default::default());
+
+ let root1 = expressions.append(
+ Expression::Unary {
+ op: UnaryOperator::Negate,
+ expr,
+ },
+ Default::default(),
+ );
+
+ let root2 = expressions.append(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ Default::default(),
+ );
+
+ let root3 = expressions.append(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr: expr1,
+ },
+ Default::default(),
+ );
+
+ let mut solver = ConstantSolver {
+ types: &mut types,
+ expressions: &expressions,
+ constants: &mut constants,
+ };
+
+ let res1 = solver.solve(root1).unwrap();
+ let res2 = solver.solve(root2).unwrap();
+ let res3 = solver.solve(root3).unwrap();
+
+ assert_eq!(
+ constants[res1].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(-4),
+ },
+ );
+
+ assert_eq!(
+ constants[res2].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(!4),
+ },
+ );
+
+ let res3_inner = &constants[res3].inner;
+
+ match *res3_inner {
+ ConstantInner::Composite {
+ ref ty,
+ ref components,
+ } => {
+ assert_eq!(*ty, vec_ty);
+ let mut components_iter = components.iter().copied();
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(!4),
+ },
+ );
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(!8),
+ },
+ );
+ assert!(components_iter.next().is_none());
+ }
+ _ => panic!("Expected vector"),
+ }
+ }
+
+ #[test]
+ fn cast() {
+ let mut expressions = Arena::new();
+ let mut constants = Arena::new();
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(4),
+ },
+ },
+ Default::default(),
+ );
+
+ let expr = expressions.append(Expression::Constant(h), Default::default());
+
+ let root = expressions.append(
+ Expression::As {
+ expr,
+ kind: ScalarKind::Bool,
+ convert: Some(crate::BOOL_WIDTH),
+ },
+ Default::default(),
+ );
+
+ let mut solver = ConstantSolver {
+ types: &mut UniqueArena::new(),
+ expressions: &expressions,
+ constants: &mut constants,
+ };
+
+ let res = solver.solve(root).unwrap();
+
+ assert_eq!(
+ constants[res].inner,
+ ConstantInner::Scalar {
+ width: crate::BOOL_WIDTH,
+ value: ScalarValue::Bool(true),
+ },
+ );
+ }
+
+ #[test]
+ fn access() {
+ let mut types = UniqueArena::new();
+ let mut expressions = Arena::new();
+ let mut constants = Arena::new();
+
+ let matrix_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns: VectorSize::Bi,
+ rows: VectorSize::Tri,
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+
+ let vec_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: VectorSize::Tri,
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+
+ let mut vec1_components = Vec::with_capacity(3);
+ let mut vec2_components = Vec::with_capacity(3);
+
+ for i in 0..3 {
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(i as f64),
+ },
+ },
+ Default::default(),
+ );
+
+ vec1_components.push(h)
+ }
+
+ for i in 3..6 {
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(i as f64),
+ },
+ },
+ Default::default(),
+ );
+
+ vec2_components.push(h)
+ }
+
+ let vec1 = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: vec_ty,
+ components: vec1_components,
+ },
+ },
+ Default::default(),
+ );
+
+ let vec2 = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: vec_ty,
+ components: vec2_components,
+ },
+ },
+ Default::default(),
+ );
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: matrix_ty,
+ components: vec![vec1, vec2],
+ },
+ },
+ Default::default(),
+ );
+
+ let base = expressions.append(Expression::Constant(h), Default::default());
+ let root1 = expressions.append(
+ Expression::AccessIndex { base, index: 1 },
+ Default::default(),
+ );
+ let root2 = expressions.append(
+ Expression::AccessIndex {
+ base: root1,
+ index: 2,
+ },
+ Default::default(),
+ );
+
+ let mut solver = ConstantSolver {
+ types: &mut types,
+ expressions: &expressions,
+ constants: &mut constants,
+ };
+
+ let res1 = solver.solve(root1).unwrap();
+ let res2 = solver.solve(root2).unwrap();
+
+ let res1_inner = &constants[res1].inner;
+
+ match *res1_inner {
+ ConstantInner::Composite {
+ ref ty,
+ ref components,
+ } => {
+ assert_eq!(*ty, vec_ty);
+ let mut components_iter = components.iter().copied();
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(3.),
+ },
+ );
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(4.),
+ },
+ );
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(5.),
+ },
+ );
+ assert!(components_iter.next().is_none());
+ }
+ _ => panic!("Expected vector"),
+ }
+
+ assert_eq!(
+ constants[res2].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(5.),
+ },
+ );
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs
new file mode 100644
index 0000000000..6df4850efa
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/context.rs
@@ -0,0 +1,1609 @@
+use super::{
+ ast::{
+ GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
+ VariableReference,
+ },
+ error::{Error, ErrorKind},
+ types::{scalar_components, type_power},
+ Frontend, Result,
+};
+use crate::{
+ front::{Emitter, Typifier},
+ AddressSpace, Arena, BinaryOperator, Block, Constant, Expression, FastHashMap,
+ FunctionArgument, Handle, LocalVariable, RelationalFunction, ScalarKind, ScalarValue, Span,
+ Statement, Type, TypeInner, VectorSize,
+};
+use std::{convert::TryFrom, ops::Index};
+
+/// The position at which an expression is, used while lowering
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum ExprPos {
+ /// The expression is in the left hand side of an assignment
+ Lhs,
+ /// The expression is in the right hand side of an assignment
+ Rhs,
+ /// The expression is an array being indexed, needed to allow constant
+ /// arrays to be dynamically indexed
+ AccessBase {
+ /// The index is a constant
+ constant_index: bool,
+ },
+}
+
+impl ExprPos {
+ /// Returns an lhs position if the current position is lhs otherwise AccessBase
+ const fn maybe_access_base(&self, constant_index: bool) -> Self {
+ match *self {
+ ExprPos::Lhs
+ | ExprPos::AccessBase {
+ constant_index: false,
+ } => *self,
+ _ => ExprPos::AccessBase { constant_index },
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Context {
+ pub expressions: Arena<Expression>,
+ pub locals: Arena<LocalVariable>,
+
+ /// The [`FunctionArgument`]s for the final [`crate::Function`].
+ ///
+ /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types
+ /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to
+ /// a [`Vector`].
+ ///
+ /// [`Pointer`]: crate::TypeInner::Pointer
+ /// [`Vector`]: crate::TypeInner::Vector
+ pub arguments: Vec<FunctionArgument>,
+
+ /// The parameter types given in the source code.
+ ///
+ /// The `out` and `inout` qualifiers don't affect the types that appear
+ /// here. For example, an `inout vec2 a` argument would simply be a
+ /// [`Vector`], not a pointer to one.
+ ///
+ /// [`Vector`]: crate::TypeInner::Vector
+ pub parameters: Vec<Handle<Type>>,
+ pub parameters_info: Vec<ParameterInfo>,
+
+ pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
+ pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,
+
+ pub typifier: Typifier,
+ emitter: Emitter,
+ stmt_ctx: Option<StmtContext>,
+}
+
+impl Context {
+ pub fn new(frontend: &Frontend, body: &mut Block) -> Self {
+ let mut this = Context {
+ expressions: Arena::new(),
+ locals: Arena::new(),
+ arguments: Vec::new(),
+
+ parameters: Vec::new(),
+ parameters_info: Vec::new(),
+
+ symbol_table: crate::front::SymbolTable::default(),
+ samplers: FastHashMap::default(),
+
+ typifier: Typifier::new(),
+ emitter: Emitter::default(),
+ stmt_ctx: Some(StmtContext::new()),
+ };
+
+ this.emit_start();
+
+ for &(ref name, lookup) in frontend.global_variables.iter() {
+ this.add_global(frontend, name, lookup, body)
+ }
+
+ this
+ }
+
+ pub fn add_global(
+ &mut self,
+ frontend: &Frontend,
+ name: &str,
+ GlobalLookup {
+ kind,
+ entry_arg,
+ mutable,
+ }: GlobalLookup,
+ body: &mut Block,
+ ) {
+ self.emit_end(body);
+ let (expr, load, constant) = match kind {
+ GlobalLookupKind::Variable(v) => {
+ let span = frontend.module.global_variables.get_span(v);
+ let res = (
+ self.expressions.append(Expression::GlobalVariable(v), span),
+ frontend.module.global_variables[v].space != AddressSpace::Handle,
+ None,
+ );
+ self.emit_start();
+
+ res
+ }
+ GlobalLookupKind::BlockSelect(handle, index) => {
+ let span = frontend.module.global_variables.get_span(handle);
+ let base = self
+ .expressions
+ .append(Expression::GlobalVariable(handle), span);
+ self.emit_start();
+ let expr = self
+ .expressions
+ .append(Expression::AccessIndex { base, index }, span);
+
+ (
+ expr,
+ {
+ let ty = frontend.module.global_variables[handle].ty;
+
+ match frontend.module.types[ty].inner {
+ TypeInner::Struct { ref members, .. } => {
+ if let TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = frontend.module.types[members[index as usize].ty].inner
+ {
+ false
+ } else {
+ true
+ }
+ }
+ _ => true,
+ }
+ },
+ None,
+ )
+ }
+ GlobalLookupKind::Constant(v, ty) => {
+ let span = frontend.module.constants.get_span(v);
+ let res = (
+ self.expressions.append(Expression::Constant(v), span),
+ false,
+ Some((v, ty)),
+ );
+ self.emit_start();
+ res
+ }
+ };
+
+ let var = VariableReference {
+ expr,
+ load,
+ mutable,
+ constant,
+ entry_arg,
+ };
+
+ self.symbol_table.add(name.into(), var);
+ }
+
+ /// Starts the expression emitter
+ ///
+ /// # Panics
+ ///
+ /// - If called twice in a row without calling [`emit_end`][Self::emit_end].
+ #[inline]
+ pub fn emit_start(&mut self) {
+ self.emitter.start(&self.expressions)
+ }
+
+ /// Emits all the expressions captured by the emitter to the passed `body`
+ ///
+ /// # Panics
+ ///
+ /// - If called before calling [`emit_start`].
+ /// - If called twice in a row without calling [`emit_start`].
+ ///
+ /// [`emit_start`]: Self::emit_start
+ pub fn emit_end(&mut self, body: &mut Block) {
+ body.extend(self.emitter.finish(&self.expressions))
+ }
+
+ /// Emits all the expressions captured by the emitter to the passed `body`
+ /// and starts the emitter again
+ ///
+ /// # Panics
+ ///
+ /// - If called before calling [`emit_start`][Self::emit_start].
+ pub fn emit_restart(&mut self, body: &mut Block) {
+ self.emit_end(body);
+ self.emit_start()
+ }
+
+ pub fn add_expression(
+ &mut self,
+ expr: Expression,
+ meta: Span,
+ body: &mut Block,
+ ) -> Handle<Expression> {
+ let needs_pre_emit = expr.needs_pre_emit();
+ if needs_pre_emit {
+ self.emit_end(body);
+ }
+ let handle = self.expressions.append(expr, meta);
+ if needs_pre_emit {
+ self.emit_start();
+ }
+ handle
+ }
+
+ /// Add variable to current scope
+ ///
+ /// Returns a variable if a variable with the same name was already defined,
+ /// otherwise returns `None`
+ pub fn add_local_var(
+ &mut self,
+ name: String,
+ expr: Handle<Expression>,
+ mutable: bool,
+ ) -> Option<VariableReference> {
+ let var = VariableReference {
+ expr,
+ load: true,
+ mutable,
+ constant: None,
+ entry_arg: None,
+ };
+
+ self.symbol_table.add(name, var)
+ }
+
+ /// Add function argument to current scope
+ pub fn add_function_arg(
+ &mut self,
+ frontend: &mut Frontend,
+ body: &mut Block,
+ name_meta: Option<(String, Span)>,
+ ty: Handle<Type>,
+ qualifier: ParameterQualifier,
+ ) {
+ let index = self.arguments.len();
+ let mut arg = FunctionArgument {
+ name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
+ ty,
+ binding: None,
+ };
+ self.parameters.push(ty);
+
+ let opaque = match frontend.module.types[ty].inner {
+ TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
+ _ => false,
+ };
+
+ if qualifier.is_lhs() {
+ let span = frontend.module.types.get_span(arg.ty);
+ arg.ty = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Pointer {
+ base: arg.ty,
+ space: AddressSpace::Function,
+ },
+ },
+ span,
+ )
+ }
+
+ self.arguments.push(arg);
+
+ self.parameters_info.push(ParameterInfo {
+ qualifier,
+ depth: false,
+ });
+
+ if let Some((name, meta)) = name_meta {
+ let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta, body);
+ let mutable = qualifier != ParameterQualifier::Const && !opaque;
+ let load = qualifier.is_lhs();
+
+ let var = if mutable && !load {
+ let handle = self.locals.append(
+ LocalVariable {
+ name: Some(name.clone()),
+ ty,
+ init: None,
+ },
+ meta,
+ );
+ let local_expr = self.add_expression(Expression::LocalVariable(handle), meta, body);
+
+ self.emit_restart(body);
+
+ body.push(
+ Statement::Store {
+ pointer: local_expr,
+ value: expr,
+ },
+ meta,
+ );
+
+ VariableReference {
+ expr: local_expr,
+ load: true,
+ mutable,
+ constant: None,
+ entry_arg: None,
+ }
+ } else {
+ VariableReference {
+ expr,
+ load,
+ mutable,
+ constant: None,
+ entry_arg: None,
+ }
+ };
+
+ self.symbol_table.add(name, var);
+ }
+ }
+
+ /// Returns a [`StmtContext`](StmtContext) to be used in parsing and lowering
+ ///
+ /// # Panics
+ /// - If more than one [`StmtContext`](StmtContext) are active at the same
+ /// time or if the previous call didn't use it in lowering.
+ #[must_use]
+ pub fn stmt_ctx(&mut self) -> StmtContext {
+ self.stmt_ctx.take().unwrap()
+ }
+
+ /// Lowers a [`HirExpr`](HirExpr) which might produce a [`Expression`](Expression).
+ ///
+ /// consumes a [`StmtContext`](StmtContext) returning it to the context so
+ /// that it can be used again later.
+ pub fn lower(
+ &mut self,
+ mut stmt: StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ body: &mut Block,
+ ) -> Result<(Option<Handle<Expression>>, Span)> {
+ let res = self.lower_inner(&stmt, frontend, expr, pos, body);
+
+ stmt.hir_exprs.clear();
+ self.stmt_ctx = Some(stmt);
+
+ res
+ }
+
+ /// Similar to [`lower`](Self::lower) but returns an error if the expression
+ /// returns void (ie. doesn't produce a [`Expression`](Expression)).
+ ///
+ /// consumes a [`StmtContext`](StmtContext) returning it to the context so
+ /// that it can be used again later.
+ pub fn lower_expect(
+ &mut self,
+ mut stmt: StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ body: &mut Block,
+ ) -> Result<(Handle<Expression>, Span)> {
+ let res = self.lower_expect_inner(&stmt, frontend, expr, pos, body);
+
+ stmt.hir_exprs.clear();
+ self.stmt_ctx = Some(stmt);
+
+ res
+ }
+
+ /// internal implementation of [`lower_expect`](Self::lower_expect)
+ ///
+ /// this method is only public because it's used in
+ /// [`function_call`](Frontend::function_call), unless you know what
+ /// you're doing use [`lower_expect`](Self::lower_expect)
+ pub fn lower_expect_inner(
+ &mut self,
+ stmt: &StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ body: &mut Block,
+ ) -> Result<(Handle<Expression>, Span)> {
+ let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos, body)?;
+
+ let expr = match maybe_expr {
+ Some(e) => e,
+ None => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError("Expression returns void".into()),
+ meta,
+ })
+ }
+ };
+
+ Ok((expr, meta))
+ }
+
+ fn lower_store(
+ &mut self,
+ pointer: Handle<Expression>,
+ value: Handle<Expression>,
+ meta: Span,
+ body: &mut Block,
+ ) {
+ if let Expression::Swizzle {
+ size,
+ mut vector,
+ pattern,
+ } = self.expressions[pointer]
+ {
+ // Stores to swizzled values are not directly supported,
+ // lower them as series of per-component stores.
+ let size = match size {
+ VectorSize::Bi => 2,
+ VectorSize::Tri => 3,
+ VectorSize::Quad => 4,
+ };
+
+ if let Expression::Load { pointer } = self.expressions[vector] {
+ vector = pointer;
+ }
+
+ #[allow(clippy::needless_range_loop)]
+ for index in 0..size {
+ let dst = self.add_expression(
+ Expression::AccessIndex {
+ base: vector,
+ index: pattern[index].index(),
+ },
+ meta,
+ body,
+ );
+ let src = self.add_expression(
+ Expression::AccessIndex {
+ base: value,
+ index: index as u32,
+ },
+ meta,
+ body,
+ );
+
+ self.emit_restart(body);
+
+ body.push(
+ Statement::Store {
+ pointer: dst,
+ value: src,
+ },
+ meta,
+ );
+ }
+ } else {
+ self.emit_restart(body);
+
+ body.push(Statement::Store { pointer, value }, meta);
+ }
+ }
+
+ /// Internal implementation of [`lower`](Self::lower)
+ fn lower_inner(
+ &mut self,
+ stmt: &StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ body: &mut Block,
+ ) -> Result<(Option<Handle<Expression>>, Span)> {
+ let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];
+
+ log::debug!("Lowering {:?} (kind {:?}, pos {:?})", expr, kind, pos);
+
+ let handle = match *kind {
+ HirExprKind::Access { base, index } => {
+ let (index, index_meta) =
+ self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs, body)?;
+ let maybe_constant_index = match pos {
+ // Don't try to generate `AccessIndex` if in a LHS position, since it
+ // wouldn't produce a pointer.
+ ExprPos::Lhs => None,
+ _ => frontend.solve_constant(self, index, index_meta).ok(),
+ };
+
+ let base = self
+ .lower_expect_inner(
+ stmt,
+ frontend,
+ base,
+ pos.maybe_access_base(maybe_constant_index.is_some()),
+ body,
+ )?
+ .0;
+
+ let pointer = maybe_constant_index
+ .and_then(|constant| {
+ Some(self.add_expression(
+ Expression::AccessIndex {
+ base,
+ index: match frontend.module.constants[constant].inner {
+ crate::ConstantInner::Scalar {
+ value: ScalarValue::Uint(i),
+ ..
+ } => u32::try_from(i).ok()?,
+ crate::ConstantInner::Scalar {
+ value: ScalarValue::Sint(i),
+ ..
+ } => u32::try_from(i).ok()?,
+ _ => return None,
+ },
+ },
+ meta,
+ body,
+ ))
+ })
+ .unwrap_or_else(|| {
+ self.add_expression(Expression::Access { base, index }, meta, body)
+ });
+
+ if ExprPos::Rhs == pos {
+ let resolved = frontend.resolve_type(self, pointer, meta)?;
+ if resolved.pointer_space().is_some() {
+ return Ok((
+ Some(self.add_expression(Expression::Load { pointer }, meta, body)),
+ meta,
+ ));
+ }
+ }
+
+ pointer
+ }
+ HirExprKind::Select { base, ref field } => {
+ let base = self.lower_expect_inner(stmt, frontend, base, pos, body)?.0;
+
+ frontend.field_selection(self, pos, body, base, field, meta)?
+ }
+ HirExprKind::Constant(constant) if pos != ExprPos::Lhs => {
+ self.add_expression(Expression::Constant(constant), meta, body)
+ }
+ HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
+ let (mut left, left_meta) =
+ self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs, body)?;
+ let (mut right, right_meta) =
+ self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs, body)?;
+
+ match op {
+ BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => self
+ .implicit_conversion(
+ frontend,
+ &mut right,
+ right_meta,
+ ScalarKind::Uint,
+ 4,
+ )?,
+ _ => self.binary_implicit_conversion(
+ frontend, &mut left, left_meta, &mut right, right_meta,
+ )?,
+ }
+
+ frontend.typifier_grow(self, left, left_meta)?;
+ frontend.typifier_grow(self, right, right_meta)?;
+
+ let left_inner = self.typifier.get(left, &frontend.module.types);
+ let right_inner = self.typifier.get(right, &frontend.module.types);
+
+ match (left_inner, right_inner) {
+ (
+ &TypeInner::Matrix {
+ columns: left_columns,
+ rows: left_rows,
+ width: left_width,
+ },
+ &TypeInner::Matrix {
+ columns: right_columns,
+ rows: right_rows,
+ width: right_width,
+ },
+ ) => {
+ let dimensions_ok = if op == BinaryOperator::Multiply {
+ left_columns == right_rows
+ } else {
+ left_columns == right_columns && left_rows == right_rows
+ };
+
+ // Check that the two arguments have the same dimensions
+ if !dimensions_ok || left_width != right_width {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "Cannot apply operation to {left_inner:?} and {right_inner:?}"
+ )
+ .into(),
+ ),
+ meta,
+ })
+ }
+
+ match op {
+ BinaryOperator::Divide => {
+ // Naga IR doesn't support matrix division so we need to
+ // divide the columns individually and reassemble the matrix
+ let mut components = Vec::with_capacity(left_columns as usize);
+
+ for index in 0..left_columns as u32 {
+ // Get the column vectors
+ let left_vector = self.add_expression(
+ Expression::AccessIndex { base: left, index },
+ meta,
+ body,
+ );
+ let right_vector = self.add_expression(
+ Expression::AccessIndex { base: right, index },
+ meta,
+ body,
+ );
+
+ // Divide the vectors
+ let column = self.expressions.append(
+ Expression::Binary {
+ op,
+ left: left_vector,
+ right: right_vector,
+ },
+ meta,
+ );
+
+ components.push(column)
+ }
+
+ // Rebuild the matrix from the divided vectors
+ self.expressions.append(
+ Expression::Compose {
+ ty: frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns: left_columns,
+ rows: left_rows,
+ width: left_width,
+ },
+ },
+ Span::default(),
+ ),
+ components,
+ },
+ meta,
+ )
+ }
+ BinaryOperator::Equal | BinaryOperator::NotEqual => {
+ // Naga IR doesn't support matrix comparisons so we need to
+ // compare the columns individually and then fold them together
+ //
+ // The folding is done using a logical and for equality and
+ // a logical or for inequality
+ let equals = op == BinaryOperator::Equal;
+
+ let (op, combine, fun) = match equals {
+ true => (
+ BinaryOperator::Equal,
+ BinaryOperator::LogicalAnd,
+ RelationalFunction::All,
+ ),
+ false => (
+ BinaryOperator::NotEqual,
+ BinaryOperator::LogicalOr,
+ RelationalFunction::Any,
+ ),
+ };
+
+ let mut root = None;
+
+ for index in 0..left_columns as u32 {
+ // Get the column vectors
+ let left_vector = self.add_expression(
+ Expression::AccessIndex { base: left, index },
+ meta,
+ body,
+ );
+ let right_vector = self.add_expression(
+ Expression::AccessIndex { base: right, index },
+ meta,
+ body,
+ );
+
+ let argument = self.expressions.append(
+ Expression::Binary {
+ op,
+ left: left_vector,
+ right: right_vector,
+ },
+ meta,
+ );
+
+ // The result of comparing two vectors is a boolean vector
+ // so use a relational function like all to get a single
+ // boolean value
+ let compare = self.add_expression(
+ Expression::Relational { fun, argument },
+ meta,
+ body,
+ );
+
+ // Fold the result
+ root = Some(match root {
+ Some(right) => self.add_expression(
+ Expression::Binary {
+ op: combine,
+ left: compare,
+ right,
+ },
+ meta,
+ body,
+ ),
+ None => compare,
+ });
+ }
+
+ root.unwrap()
+ }
+ _ => self.add_expression(
+ Expression::Binary { left, op, right },
+ meta,
+ body,
+ ),
+ }
+ }
+ (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
+ BinaryOperator::Equal | BinaryOperator::NotEqual => {
+ let equals = op == BinaryOperator::Equal;
+
+ let (op, fun) = match equals {
+ true => (BinaryOperator::Equal, RelationalFunction::All),
+ false => (BinaryOperator::NotEqual, RelationalFunction::Any),
+ };
+
+ let argument = self
+ .expressions
+ .append(Expression::Binary { op, left, right }, meta);
+
+ self.add_expression(
+ Expression::Relational { fun, argument },
+ meta,
+ body,
+ )
+ }
+ _ => {
+ self.add_expression(Expression::Binary { left, op, right }, meta, body)
+ }
+ },
+ (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
+ BinaryOperator::Add
+ | BinaryOperator::Subtract
+ | BinaryOperator::Divide
+ | BinaryOperator::And
+ | BinaryOperator::ExclusiveOr
+ | BinaryOperator::InclusiveOr
+ | BinaryOperator::ShiftLeft
+ | BinaryOperator::ShiftRight => {
+ let scalar_vector = self.add_expression(
+ Expression::Splat { size, value: right },
+ meta,
+ body,
+ );
+
+ self.add_expression(
+ Expression::Binary {
+ op,
+ left,
+ right: scalar_vector,
+ },
+ meta,
+ body,
+ )
+ }
+ _ => {
+ self.add_expression(Expression::Binary { left, op, right }, meta, body)
+ }
+ },
+ (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
+ BinaryOperator::Add
+ | BinaryOperator::Subtract
+ | BinaryOperator::Divide
+ | BinaryOperator::And
+ | BinaryOperator::ExclusiveOr
+ | BinaryOperator::InclusiveOr => {
+ let scalar_vector = self.add_expression(
+ Expression::Splat { size, value: left },
+ meta,
+ body,
+ );
+
+ self.add_expression(
+ Expression::Binary {
+ op,
+ left: scalar_vector,
+ right,
+ },
+ meta,
+ body,
+ )
+ }
+ _ => {
+ self.add_expression(Expression::Binary { left, op, right }, meta, body)
+ }
+ },
+ (
+ &TypeInner::Scalar {
+ width: left_width, ..
+ },
+ &TypeInner::Matrix {
+ rows,
+ columns,
+ width: right_width,
+ },
+ ) => {
+ // Check that the two arguments have the same width
+ if left_width != right_width {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "Cannot apply operation to {left_inner:?} and {right_inner:?}"
+ )
+ .into(),
+ ),
+ meta,
+ })
+ }
+
+ match op {
+ BinaryOperator::Divide
+ | BinaryOperator::Add
+ | BinaryOperator::Subtract => {
+ // Naga IR doesn't support all matrix by scalar operations so
+ // we need for some to turn the scalar into a vector by
+ // splatting it and then for each column vector apply the
+ // operation and finally reconstruct the matrix
+ let scalar_vector = self.add_expression(
+ Expression::Splat {
+ size: rows,
+ value: left,
+ },
+ meta,
+ body,
+ );
+
+ let mut components = Vec::with_capacity(columns as usize);
+
+ for index in 0..columns as u32 {
+ // Get the column vector
+ let matrix_column = self.add_expression(
+ Expression::AccessIndex { base: right, index },
+ meta,
+ body,
+ );
+
+ // Apply the operation to the splatted vector and
+ // the column vector
+ let column = self.expressions.append(
+ Expression::Binary {
+ op,
+ left: scalar_vector,
+ right: matrix_column,
+ },
+ meta,
+ );
+
+ components.push(column)
+ }
+
+ // Rebuild the matrix from the operation result vectors
+ self.expressions.append(
+ Expression::Compose {
+ ty: frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ width: left_width,
+ },
+ },
+ Span::default(),
+ ),
+ components,
+ },
+ meta,
+ )
+ }
+ _ => self.add_expression(
+ Expression::Binary { left, op, right },
+ meta,
+ body,
+ ),
+ }
+ }
+ (
+ &TypeInner::Matrix {
+ rows,
+ columns,
+ width: left_width,
+ },
+ &TypeInner::Scalar {
+ width: right_width, ..
+ },
+ ) => {
+ // Check that the two arguments have the same width
+ if left_width != right_width {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "Cannot apply operation to {left_inner:?} and {right_inner:?}"
+ )
+ .into(),
+ ),
+ meta,
+ })
+ }
+
+ match op {
+ BinaryOperator::Divide
+ | BinaryOperator::Add
+ | BinaryOperator::Subtract => {
+ // Naga IR doesn't support all matrix by scalar operations so
+ // we need for some to turn the scalar into a vector by
+ // splatting it and then for each column vector apply the
+ // operation and finally reconstruct the matrix
+
+ let scalar_vector = self.add_expression(
+ Expression::Splat {
+ size: rows,
+ value: right,
+ },
+ meta,
+ body,
+ );
+
+ let mut components = Vec::with_capacity(columns as usize);
+
+ for index in 0..columns as u32 {
+ // Get the column vector
+ let matrix_column = self.add_expression(
+ Expression::AccessIndex { base: left, index },
+ meta,
+ body,
+ );
+
+ // Apply the operation to the splatted vector and
+ // the column vector
+ let column = self.expressions.append(
+ Expression::Binary {
+ op,
+ left: matrix_column,
+ right: scalar_vector,
+ },
+ meta,
+ );
+
+ components.push(column)
+ }
+
+ // Rebuild the matrix from the operation result vectors
+ self.expressions.append(
+ Expression::Compose {
+ ty: frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ width: left_width,
+ },
+ },
+ Span::default(),
+ ),
+ components,
+ },
+ meta,
+ )
+ }
+ _ => self.add_expression(
+ Expression::Binary { left, op, right },
+ meta,
+ body,
+ ),
+ }
+ }
+ _ => self.add_expression(Expression::Binary { left, op, right }, meta, body),
+ }
+ }
+ HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
+ let expr = self
+ .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs, body)?
+ .0;
+
+ self.add_expression(Expression::Unary { op, expr }, meta, body)
+ }
+ HirExprKind::Variable(ref var) => match pos {
+ ExprPos::Lhs => {
+ if !var.mutable {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Variable cannot be used in LHS position".into(),
+ ),
+ meta,
+ })
+ }
+
+ var.expr
+ }
+ ExprPos::AccessBase { constant_index } => {
+ // If the index isn't constant all accesses backed by a constant base need
+ // to be done through a proxy local variable, since constants have a non
+ // pointer type which is required for dynamic indexing
+ if !constant_index {
+ if let Some((constant, ty)) = var.constant {
+ let local = self.locals.append(
+ LocalVariable {
+ name: None,
+ ty,
+ init: Some(constant),
+ },
+ Span::default(),
+ );
+
+ self.add_expression(
+ Expression::LocalVariable(local),
+ Span::default(),
+ body,
+ )
+ } else {
+ var.expr
+ }
+ } else {
+ var.expr
+ }
+ }
+ _ if var.load => {
+ self.add_expression(Expression::Load { pointer: var.expr }, meta, body)
+ }
+ ExprPos::Rhs => var.expr,
+ },
+ HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
+ let maybe_expr = frontend.function_or_constructor_call(
+ self,
+ stmt,
+ body,
+ call.kind.clone(),
+ &call.args,
+ meta,
+ )?;
+ return Ok((maybe_expr, meta));
+ }
+ // `HirExprKind::Conditional` represents the ternary operator in glsl (`:?`)
+ //
+ // The ternary operator is defined to only evaluate one of the two possible
+ // expressions which means that it's behavior is that of an `if` statement,
+ // and it's merely syntatic sugar for it.
+ HirExprKind::Conditional {
+ condition,
+ accept,
+ reject,
+ } if ExprPos::Lhs != pos => {
+ // Given an expression `a ? b : c`, we need to produce a Naga
+ // statement roughly like:
+ //
+ // var temp;
+ // if a {
+ // temp = convert(b);
+ // } else {
+ // temp = convert(c);
+ // }
+ //
+ // where `convert` stands for type conversions to bring `b` and `c` to
+ // the same type, and then use `temp` to represent the value of the whole
+ // conditional expression in subsequent code.
+
+ // Lower the condition first to the current bodyy
+ let condition = self
+ .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs, body)?
+ .0;
+
+ // Emit all expressions since we will be adding statements to
+ // other bodies next
+ self.emit_restart(body);
+
+ // Create the bodies for the two cases
+ let mut accept_body = Block::new();
+ let mut reject_body = Block::new();
+
+ // Lower the `true` branch
+ let (mut accept, accept_meta) =
+ self.lower_expect_inner(stmt, frontend, accept, pos, &mut accept_body)?;
+
+ // Flush the body of the `true` branch, to start emitting on the
+ // `false` branch
+ self.emit_restart(&mut accept_body);
+
+ // Lower the `false` branch
+ let (mut reject, reject_meta) =
+ self.lower_expect_inner(stmt, frontend, reject, pos, &mut reject_body)?;
+
+ // Flush the body of the `false` branch
+ self.emit_restart(&mut reject_body);
+
+ // We need to do some custom implicit conversions since the two target expressions
+ // are in different bodies
+ if let (
+ Some((accept_power, accept_width, accept_kind)),
+ Some((reject_power, reject_width, reject_kind)),
+ ) = (
+ // Get the components of both branches and calculate the type power
+ self.expr_scalar_components(frontend, accept, accept_meta)?
+ .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))),
+ self.expr_scalar_components(frontend, reject, reject_meta)?
+ .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))),
+ ) {
+ match accept_power.cmp(&reject_power) {
+ std::cmp::Ordering::Less => {
+ self.conversion(&mut accept, accept_meta, reject_kind, reject_width)?;
+ // The expression belongs to the `true` branch so we need to flush to
+ // the respective body
+ self.emit_end(&mut accept_body);
+ }
+ // Technically there's nothing to flush but later we will need to
+ // add some expressions that must not be emitted so instead
+ // of flushing, starting and flushing again, just make sure
+ // everything is flushed.
+ std::cmp::Ordering::Equal => self.emit_end(body),
+ std::cmp::Ordering::Greater => {
+ self.conversion(&mut reject, reject_meta, accept_kind, accept_width)?;
+ // The expression belongs to the `false` branch so we need to flush to
+ // the respective body
+ self.emit_end(&mut reject_body);
+ }
+ }
+ } else {
+ // Technically there's nothing to flush but later we will need to
+ // add some expressions that must not be emitted.
+ self.emit_end(body)
+ }
+
+ // We need to get the type of the resulting expression to create the local,
+ // this must be done after implicit conversions to ensure both branches have
+ // the same type.
+ let ty = frontend.resolve_type_handle(self, accept, accept_meta)?;
+
+ // Add the local that will hold the result of our conditional
+ let local = self.locals.append(
+ LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ meta,
+ );
+
+ // Note: `Expression::LocalVariable` must not be emited so it's important
+ // that at this point the emitter is flushed but not started.
+ let local_expr = self
+ .expressions
+ .append(Expression::LocalVariable(local), meta);
+
+ // Add to each body the store to the result variable
+ accept_body.push(
+ Statement::Store {
+ pointer: local_expr,
+ value: accept,
+ },
+ accept_meta,
+ );
+ reject_body.push(
+ Statement::Store {
+ pointer: local_expr,
+ value: reject,
+ },
+ reject_meta,
+ );
+
+ // Finally add the `If` to the main body with the `condition` we lowered
+ // earlier and the branches we prepared.
+ body.push(
+ Statement::If {
+ condition,
+ accept: accept_body,
+ reject: reject_body,
+ },
+ meta,
+ );
+
+ // Restart the emitter
+ self.emit_start();
+
+ // Note: `Expression::Load` must be emited before it's used so make
+ // sure the emitter is active here.
+ self.expressions.append(
+ Expression::Load {
+ pointer: local_expr,
+ },
+ meta,
+ )
+ }
+ HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
+ let (pointer, ptr_meta) =
+ self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs, body)?;
+ let (mut value, value_meta) =
+ self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs, body)?;
+
+ let ty = match *frontend.resolve_type(self, pointer, ptr_meta)? {
+ TypeInner::Pointer { base, .. } => &frontend.module.types[base].inner,
+ ref ty => ty,
+ };
+
+ if let Some((kind, width)) = scalar_components(ty) {
+ self.implicit_conversion(frontend, &mut value, value_meta, kind, width)?;
+ }
+
+ self.lower_store(pointer, value, meta, body);
+
+ value
+ }
+ HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
+ let (pointer, _) =
+ self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs, body)?;
+ let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
+ pointer
+ } else {
+ self.add_expression(Expression::Load { pointer }, meta, body)
+ };
+
+ let make_constant_inner = |kind, width| {
+ let value = match kind {
+ ScalarKind::Sint => crate::ScalarValue::Sint(1),
+ ScalarKind::Uint => crate::ScalarValue::Uint(1),
+ ScalarKind::Float => crate::ScalarValue::Float(1.0),
+ ScalarKind::Bool => return None,
+ };
+
+ Some(crate::ConstantInner::Scalar { width, value })
+ };
+ let res = match *frontend.resolve_type(self, left, meta)? {
+ TypeInner::Scalar { kind, width } => {
+ let ty = TypeInner::Scalar { kind, width };
+ make_constant_inner(kind, width).map(|i| (ty, i, None, None))
+ }
+ TypeInner::Vector { size, kind, width } => {
+ let ty = TypeInner::Vector { size, kind, width };
+ make_constant_inner(kind, width).map(|i| (ty, i, Some(size), None))
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let ty = TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ };
+ make_constant_inner(ScalarKind::Float, width)
+ .map(|i| (ty, i, Some(rows), Some(columns)))
+ }
+ _ => None,
+ };
+ let (ty_inner, inner, rows, columns) = match res {
+ Some(res) => res,
+ None => {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Increment/decrement only works on scalar/vector/matrix".into(),
+ ),
+ meta,
+ });
+ return Ok((Some(left), meta));
+ }
+ };
+
+ let constant_1 = frontend.module.constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ Default::default(),
+ );
+ let mut right = self.add_expression(Expression::Constant(constant_1), meta, body);
+
+ // Glsl allows pre/postfixes operations on vectors and matrices, so if the
+ // target is either of them change the right side of the addition to be splatted
+ // to the same size as the target, furthermore if the target is a matrix
+ // use a composed matrix using the splatted value.
+ if let Some(size) = rows {
+ right =
+ self.add_expression(Expression::Splat { size, value: right }, meta, body);
+
+ if let Some(cols) = columns {
+ let ty = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: ty_inner,
+ },
+ meta,
+ );
+
+ right = self.add_expression(
+ Expression::Compose {
+ ty,
+ components: std::iter::repeat(right).take(cols as usize).collect(),
+ },
+ meta,
+ body,
+ );
+ }
+ }
+
+ let value = self.add_expression(Expression::Binary { op, left, right }, meta, body);
+
+ self.lower_store(pointer, value, meta, body);
+
+ if postfix {
+ left
+ } else {
+ value
+ }
+ }
+ HirExprKind::Method {
+ expr: object,
+ ref name,
+ ref args,
+ } if ExprPos::Lhs != pos => {
+ let args = args
+ .iter()
+ .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs, body))
+ .collect::<Result<Vec<_>>>()?;
+ match name.as_ref() {
+ "length" => {
+ if !args.is_empty() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ ".length() doesn't take any arguments".into(),
+ ),
+ meta,
+ });
+ }
+ let lowered_array = self
+ .lower_expect_inner(stmt, frontend, object, pos, body)?
+ .0;
+ let array_type = frontend.resolve_type(self, lowered_array, meta)?;
+
+ match *array_type {
+ TypeInner::Array {
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ let mut array_length =
+ self.add_expression(Expression::Constant(size), meta, body);
+ self.forced_conversion(
+ frontend,
+ &mut array_length,
+ meta,
+ ScalarKind::Sint,
+ 4,
+ )?;
+ array_length
+ }
+ // let the error be handled in type checking if it's not a dynamic array
+ _ => {
+ let mut array_length = self.add_expression(
+ Expression::ArrayLength(lowered_array),
+ meta,
+ body,
+ );
+ self.conversion(&mut array_length, meta, ScalarKind::Sint, 4)?;
+ array_length
+ }
+ }
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("unknown method '{name}'").into(),
+ ),
+ meta,
+ });
+ }
+ }
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
+ .into(),
+ ),
+ meta,
+ })
+ }
+ };
+
+ log::trace!(
+ "Lowered {:?}\n\tKind = {:?}\n\tPos = {:?}\n\tResult = {:?}",
+ expr,
+ kind,
+ pos,
+ handle
+ );
+
+ Ok((Some(handle), meta))
+ }
+
+ pub fn expr_scalar_components(
+ &mut self,
+ frontend: &Frontend,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Option<(ScalarKind, crate::Bytes)>> {
+ let ty = frontend.resolve_type(self, expr, meta)?;
+ Ok(scalar_components(ty))
+ }
+
+ pub fn expr_power(
+ &mut self,
+ frontend: &Frontend,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Option<u32>> {
+ Ok(self
+ .expr_scalar_components(frontend, expr, meta)?
+ .and_then(|(kind, width)| type_power(kind, width)))
+ }
+
+ pub fn conversion(
+ &mut self,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ kind: ScalarKind,
+ width: crate::Bytes,
+ ) -> Result<()> {
+ *expr = self.expressions.append(
+ Expression::As {
+ expr: *expr,
+ kind,
+ convert: Some(width),
+ },
+ meta,
+ );
+
+ Ok(())
+ }
+
+ pub fn implicit_conversion(
+ &mut self,
+ frontend: &Frontend,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ kind: ScalarKind,
+ width: crate::Bytes,
+ ) -> Result<()> {
+ if let (Some(tgt_power), Some(expr_power)) = (
+ type_power(kind, width),
+ self.expr_power(frontend, *expr, meta)?,
+ ) {
+ if tgt_power > expr_power {
+ self.conversion(expr, meta, kind, width)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn forced_conversion(
+ &mut self,
+ frontend: &Frontend,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ kind: ScalarKind,
+ width: crate::Bytes,
+ ) -> Result<()> {
+ if let Some((expr_scalar_kind, expr_width)) =
+ self.expr_scalar_components(frontend, *expr, meta)?
+ {
+ if expr_scalar_kind != kind || expr_width != width {
+ self.conversion(expr, meta, kind, width)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn binary_implicit_conversion(
+ &mut self,
+ frontend: &Frontend,
+ left: &mut Handle<Expression>,
+ left_meta: Span,
+ right: &mut Handle<Expression>,
+ right_meta: Span,
+ ) -> Result<()> {
+ let left_components = self.expr_scalar_components(frontend, *left, left_meta)?;
+ let right_components = self.expr_scalar_components(frontend, *right, right_meta)?;
+
+ if let (
+ Some((left_power, left_width, left_kind)),
+ Some((right_power, right_width, right_kind)),
+ ) = (
+ left_components.and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))),
+ right_components
+ .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))),
+ ) {
+ match left_power.cmp(&right_power) {
+ std::cmp::Ordering::Less => {
+ self.conversion(left, left_meta, right_kind, right_width)?;
+ }
+ std::cmp::Ordering::Equal => {}
+ std::cmp::Ordering::Greater => {
+ self.conversion(right, right_meta, left_kind, left_width)?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn implicit_splat(
+ &mut self,
+ frontend: &Frontend,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ vector_size: Option<VectorSize>,
+ ) -> Result<()> {
+ let expr_type = frontend.resolve_type(self, *expr, meta)?;
+
+ if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
+ *expr = self
+ .expressions
+ .append(Expression::Splat { size, value: *expr }, meta)
+ }
+
+ Ok(())
+ }
+
+ pub fn vector_resize(
+ &mut self,
+ size: VectorSize,
+ vector: Handle<Expression>,
+ meta: Span,
+ body: &mut Block,
+ ) -> Handle<Expression> {
+ self.add_expression(
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern: crate::SwizzleComponent::XYZW,
+ },
+ meta,
+ body,
+ )
+ }
+}
+
+impl Index<Handle<Expression>> for Context {
+ type Output = Expression;
+
+ fn index(&self, index: Handle<Expression>) -> &Self::Output {
+ &self.expressions[index]
+ }
+}
+
+/// Helper struct passed when parsing expressions
+///
+/// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx)
+/// and only one of these may be active at any time per context.
+#[derive(Debug)]
+pub struct StmtContext {
+ /// A arena of high level expressions which can be lowered through a
+ /// [`Context`](Context) to naga's [`Expression`](crate::Expression)s
+ pub hir_exprs: Arena<HirExpr>,
+}
+
+impl StmtContext {
+ const fn new() -> Self {
+ StmtContext {
+ hir_exprs: Arena::new(),
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs
new file mode 100644
index 0000000000..46b3aecd08
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/error.rs
@@ -0,0 +1,134 @@
+use super::{constants::ConstantSolvingError, token::TokenValue};
+use crate::Span;
+use pp_rs::token::PreprocessorError;
+use std::borrow::Cow;
+use thiserror::Error;
+
+fn join_with_comma(list: &[ExpectedToken]) -> String {
+ let mut string = "".to_string();
+ for (i, val) in list.iter().enumerate() {
+ string.push_str(&val.to_string());
+ match i {
+ i if i == list.len() - 1 => {}
+ i if i == list.len() - 2 => string.push_str(" or "),
+ _ => string.push_str(", "),
+ }
+ }
+ string
+}
+
+/// One of the expected tokens returned in [`InvalidToken`](ErrorKind::InvalidToken).
+#[derive(Debug, PartialEq)]
+pub enum ExpectedToken {
+ /// A specific token was expected.
+ Token(TokenValue),
+ /// A type was expected.
+ TypeName,
+ /// An identifier was expected.
+ Identifier,
+ /// An integer literal was expected.
+ IntLiteral,
+ /// A float literal was expected.
+ FloatLiteral,
+ /// A boolean literal was expected.
+ BoolLiteral,
+ /// The end of file was expected.
+ Eof,
+}
+impl From<TokenValue> for ExpectedToken {
+ fn from(token: TokenValue) -> Self {
+ ExpectedToken::Token(token)
+ }
+}
+impl std::fmt::Display for ExpectedToken {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match *self {
+ ExpectedToken::Token(ref token) => write!(f, "{token:?}"),
+ ExpectedToken::TypeName => write!(f, "a type"),
+ ExpectedToken::Identifier => write!(f, "identifier"),
+ ExpectedToken::IntLiteral => write!(f, "integer literal"),
+ ExpectedToken::FloatLiteral => write!(f, "float literal"),
+ ExpectedToken::BoolLiteral => write!(f, "bool literal"),
+ ExpectedToken::Eof => write!(f, "end of file"),
+ }
+ }
+}
+
+/// Information about the cause of an error.
+#[derive(Debug, Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum ErrorKind {
+ /// Whilst parsing as encountered an unexpected EOF.
+ #[error("Unexpected end of file")]
+ EndOfFile,
+ /// The shader specified an unsupported or invalid profile.
+ #[error("Invalid profile: {0}")]
+ InvalidProfile(String),
+ /// The shader requested an unsupported or invalid version.
+ #[error("Invalid version: {0}")]
+ InvalidVersion(u64),
+ /// Whilst parsing an unexpected token was encountered.
+ ///
+ /// A list of expected tokens is also returned.
+ #[error("Expected {}, found {0:?}", join_with_comma(.1))]
+ InvalidToken(TokenValue, Vec<ExpectedToken>),
+ /// A specific feature is not yet implemented.
+ ///
+ /// To help prioritize work please open an issue in the github issue tracker
+ /// if none exist already or react to the already existing one.
+ #[error("Not implemented: {0}")]
+ NotImplemented(&'static str),
+ /// A reference to a variable that wasn't declared was used.
+ #[error("Unknown variable: {0}")]
+ UnknownVariable(String),
+ /// A reference to a type that wasn't declared was used.
+ #[error("Unknown type: {0}")]
+ UnknownType(String),
+ /// A reference to a non existent member of a type was made.
+ #[error("Unknown field: {0}")]
+ UnknownField(String),
+ /// An unknown layout qualifier was used.
+ ///
+ /// If the qualifier does exist please open an issue in the github issue tracker
+ /// if none exist already or react to the already existing one to help
+ /// prioritize work.
+ #[error("Unknown layout qualifier: {0}")]
+ UnknownLayoutQualifier(String),
+ /// Unsupported matrix of the form matCx2
+ ///
+ /// Our IR expects matrices of the form matCx2 to have a stride of 8 however
+ /// matrices in the std140 layout have a stride of at least 16
+ #[error("unsupported matrix of the form matCx2 in std140 block layout")]
+ UnsupportedMatrixTypeInStd140,
+ /// A variable with the same name already exists in the current scope.
+ #[error("Variable already declared: {0}")]
+ VariableAlreadyDeclared(String),
+ /// A semantic error was detected in the shader.
+ #[error("{0}")]
+ SemanticError(Cow<'static, str>),
+ /// An error was returned by the preprocessor.
+ #[error("{0:?}")]
+ PreprocessorError(PreprocessorError),
+ /// The parser entered an illegal state and exited
+ ///
+ /// This obviously is a bug and as such should be reported in the github issue tracker
+ #[error("Internal error: {0}")]
+ InternalError(&'static str),
+}
+
+impl From<ConstantSolvingError> for ErrorKind {
+ fn from(err: ConstantSolvingError) -> Self {
+ ErrorKind::SemanticError(err.to_string().into())
+ }
+}
+
+/// Error returned during shader parsing.
+#[derive(Debug, Error)]
+#[error("{kind}")]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Error {
+ /// Holds the information about the error itself.
+ pub kind: ErrorKind,
+ /// Holds information about the range of the source code where the error happened.
+ pub meta: Span,
+}
diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs
new file mode 100644
index 0000000000..10c964b5e0
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/functions.rs
@@ -0,0 +1,1680 @@
+use super::{
+ ast::*,
+ builtins::{inject_builtin, sampled_to_depth},
+ context::{Context, ExprPos, StmtContext},
+ error::{Error, ErrorKind},
+ types::scalar_components,
+ Frontend, Result,
+};
+use crate::{
+ front::glsl::types::type_power, proc::ensure_block_returns, AddressSpace, Arena, Block,
+ Constant, ConstantInner, EntryPoint, Expression, Function, FunctionArgument, FunctionResult,
+ Handle, LocalVariable, ScalarKind, ScalarValue, Span, Statement, StructMember, Type, TypeInner,
+};
+use std::iter;
+
+/// Struct detailing a store operation that must happen after a function call
+struct ProxyWrite {
+ /// The store target
+ target: Handle<Expression>,
+ /// A pointer to read the value of the store
+ value: Handle<Expression>,
+ /// An optional conversion to be applied
+ convert: Option<(ScalarKind, crate::Bytes)>,
+}
+
+impl Frontend {
+ fn add_constant_value(
+ &mut self,
+ scalar_kind: ScalarKind,
+ value: u64,
+ meta: Span,
+ ) -> Handle<Constant> {
+ let value = match scalar_kind {
+ ScalarKind::Uint => ScalarValue::Uint(value),
+ ScalarKind::Sint => ScalarValue::Sint(value as i64),
+ ScalarKind::Float => ScalarValue::Float(value as f64),
+ ScalarKind::Bool => unreachable!(),
+ };
+
+ self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar { width: 4, value },
+ },
+ meta,
+ )
+ }
+
+ pub(crate) fn function_or_constructor_call(
+ &mut self,
+ ctx: &mut Context,
+ stmt: &StmtContext,
+ body: &mut Block,
+ fc: FunctionCallKind,
+ raw_args: &[Handle<HirExpr>],
+ meta: Span,
+ ) -> Result<Option<Handle<Expression>>> {
+ let args: Vec<_> = raw_args
+ .iter()
+ .map(|e| ctx.lower_expect_inner(stmt, self, *e, ExprPos::Rhs, body))
+ .collect::<Result<_>>()?;
+
+ match fc {
+ FunctionCallKind::TypeConstructor(ty) => {
+ if args.len() == 1 {
+ self.constructor_single(ctx, body, ty, args[0], meta)
+ .map(Some)
+ } else {
+ self.constructor_many(ctx, body, ty, args, meta).map(Some)
+ }
+ }
+ FunctionCallKind::Function(name) => {
+ self.function_call(ctx, stmt, body, name, args, raw_args, meta)
+ }
+ }
+ }
+
+ fn constructor_single(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ ty: Handle<Type>,
+ (mut value, expr_meta): (Handle<Expression>, Span),
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let expr_type = self.resolve_type(ctx, value, expr_meta)?;
+
+ let vector_size = match *expr_type {
+ TypeInner::Vector { size, .. } => Some(size),
+ _ => None,
+ };
+
+ // Special case: if casting from a bool, we need to use Select and not As.
+ match self.module.types[ty].inner.scalar_kind() {
+ Some(result_scalar_kind)
+ if expr_type.scalar_kind() == Some(ScalarKind::Bool)
+ && result_scalar_kind != ScalarKind::Bool =>
+ {
+ let c0 = self.add_constant_value(result_scalar_kind, 0u64, meta);
+ let c1 = self.add_constant_value(result_scalar_kind, 1u64, meta);
+ let mut reject = ctx.add_expression(Expression::Constant(c0), expr_meta, body);
+ let mut accept = ctx.add_expression(Expression::Constant(c1), expr_meta, body);
+
+ ctx.implicit_splat(self, &mut reject, meta, vector_size)?;
+ ctx.implicit_splat(self, &mut accept, meta, vector_size)?;
+
+ let h = ctx.add_expression(
+ Expression::Select {
+ accept,
+ reject,
+ condition: value,
+ },
+ expr_meta,
+ body,
+ );
+
+ return Ok(h);
+ }
+ _ => {}
+ }
+
+ Ok(match self.module.types[ty].inner {
+ TypeInner::Vector { size, kind, width } if vector_size.is_none() => {
+ ctx.forced_conversion(self, &mut value, expr_meta, kind, width)?;
+
+ if let TypeInner::Scalar { .. } = *self.resolve_type(ctx, value, expr_meta)? {
+ ctx.add_expression(Expression::Splat { size, value }, meta, body)
+ } else {
+ self.vector_constructor(
+ ctx,
+ body,
+ ty,
+ size,
+ kind,
+ width,
+ &[(value, expr_meta)],
+ meta,
+ )?
+ }
+ }
+ TypeInner::Scalar { kind, width } => {
+ let mut expr = value;
+ if let TypeInner::Vector { .. } | TypeInner::Matrix { .. } =
+ *self.resolve_type(ctx, value, expr_meta)?
+ {
+ expr = ctx.add_expression(
+ Expression::AccessIndex {
+ base: expr,
+ index: 0,
+ },
+ meta,
+ body,
+ );
+ }
+
+ if let TypeInner::Matrix { .. } = *self.resolve_type(ctx, value, expr_meta)? {
+ expr = ctx.add_expression(
+ Expression::AccessIndex {
+ base: expr,
+ index: 0,
+ },
+ meta,
+ body,
+ );
+ }
+
+ ctx.add_expression(
+ Expression::As {
+ kind,
+ expr,
+ convert: Some(width),
+ },
+ meta,
+ body,
+ )
+ }
+ TypeInner::Vector { size, kind, width } => {
+ if vector_size.map_or(true, |s| s != size) {
+ value = ctx.vector_resize(size, value, expr_meta, body);
+ }
+
+ ctx.add_expression(
+ Expression::As {
+ kind,
+ expr: value,
+ convert: Some(width),
+ },
+ meta,
+ body,
+ )
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => self.matrix_one_arg(
+ ctx,
+ body,
+ ty,
+ columns,
+ rows,
+ width,
+ (value, expr_meta),
+ meta,
+ )?,
+ TypeInner::Struct { ref members, .. } => {
+ let scalar_components = members
+ .get(0)
+ .and_then(|member| scalar_components(&self.module.types[member.ty].inner));
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(self, &mut value, expr_meta, kind, width)?;
+ }
+
+ ctx.add_expression(
+ Expression::Compose {
+ ty,
+ components: vec![value],
+ },
+ meta,
+ body,
+ )
+ }
+
+ TypeInner::Array { base, .. } => {
+ let scalar_components = scalar_components(&self.module.types[base].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(self, &mut value, expr_meta, kind, width)?;
+ }
+
+ ctx.add_expression(
+ Expression::Compose {
+ ty,
+ components: vec![value],
+ },
+ meta,
+ body,
+ )
+ }
+ _ => {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Bad type constructor".into()),
+ meta,
+ });
+
+ value
+ }
+ })
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn matrix_one_arg(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ ty: Handle<Type>,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ (mut value, expr_meta): (Handle<Expression>, Span),
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let mut components = Vec::with_capacity(columns as usize);
+ // TODO: casts
+ // `Expression::As` doesn't support matrix width
+ // casts so we need to do some extra work for casts
+
+ ctx.forced_conversion(self, &mut value, expr_meta, ScalarKind::Float, width)?;
+ match *self.resolve_type(ctx, value, expr_meta)? {
+ TypeInner::Scalar { .. } => {
+ // If a matrix is constructed with a single scalar value, then that
+ // value is used to initialize all the values along the diagonal of
+ // the matrix; the rest are given zeros.
+ let vector_ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ kind: ScalarKind::Float,
+ width,
+ },
+ },
+ meta,
+ );
+ let zero_constant = self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width,
+ value: ScalarValue::Float(0.0),
+ },
+ },
+ meta,
+ );
+ let zero = ctx.add_expression(Expression::Constant(zero_constant), meta, body);
+
+ for i in 0..columns as u32 {
+ components.push(
+ ctx.add_expression(
+ Expression::Compose {
+ ty: vector_ty,
+ components: (0..rows as u32)
+ .map(|r| match r == i {
+ true => value,
+ false => zero,
+ })
+ .collect(),
+ },
+ meta,
+ body,
+ ),
+ )
+ }
+ }
+ TypeInner::Matrix {
+ rows: ori_rows,
+ columns: ori_cols,
+ ..
+ } => {
+ // If a matrix is constructed from a matrix, then each component
+ // (column i, row j) in the result that has a corresponding component
+ // (column i, row j) in the argument will be initialized from there. All
+ // other components will be initialized to the identity matrix.
+
+ let zero_constant = self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width,
+ value: ScalarValue::Float(0.0),
+ },
+ },
+ meta,
+ );
+ let zero = ctx.add_expression(Expression::Constant(zero_constant), meta, body);
+ let one_constant = self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width,
+ value: ScalarValue::Float(1.0),
+ },
+ },
+ meta,
+ );
+ let one = ctx.add_expression(Expression::Constant(one_constant), meta, body);
+ let vector_ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ kind: ScalarKind::Float,
+ width,
+ },
+ },
+ meta,
+ );
+
+ for i in 0..columns as u32 {
+ if i < ori_cols as u32 {
+ use std::cmp::Ordering;
+
+ let vector = ctx.add_expression(
+ Expression::AccessIndex {
+ base: value,
+ index: i,
+ },
+ meta,
+ body,
+ );
+
+ components.push(match ori_rows.cmp(&rows) {
+ Ordering::Less => {
+ let components = (0..rows as u32)
+ .map(|r| {
+ if r < ori_rows as u32 {
+ ctx.add_expression(
+ Expression::AccessIndex {
+ base: vector,
+ index: r,
+ },
+ meta,
+ body,
+ )
+ } else if r == i {
+ one
+ } else {
+ zero
+ }
+ })
+ .collect();
+
+ ctx.add_expression(
+ Expression::Compose {
+ ty: vector_ty,
+ components,
+ },
+ meta,
+ body,
+ )
+ }
+ Ordering::Equal => vector,
+ Ordering::Greater => ctx.vector_resize(rows, vector, meta, body),
+ })
+ } else {
+ let vec_constant = self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: vector_ty,
+ components: (0..rows as u32)
+ .map(|r| match r == i {
+ true => one_constant,
+ false => zero_constant,
+ })
+ .collect(),
+ },
+ },
+ meta,
+ );
+ let vec =
+ ctx.add_expression(Expression::Constant(vec_constant), meta, body);
+
+ components.push(vec)
+ }
+ }
+ }
+ _ => {
+ components = iter::repeat(value).take(columns as usize).collect();
+ }
+ }
+
+ Ok(ctx.add_expression(Expression::Compose { ty, components }, meta, body))
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn vector_constructor(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ ty: Handle<Type>,
+ size: crate::VectorSize,
+ kind: ScalarKind,
+ width: crate::Bytes,
+ args: &[(Handle<Expression>, Span)],
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let mut components = Vec::with_capacity(size as usize);
+
+ for (mut arg, expr_meta) in args.iter().copied() {
+ ctx.forced_conversion(self, &mut arg, expr_meta, kind, width)?;
+
+ if components.len() >= size as usize {
+ break;
+ }
+
+ match *self.resolve_type(ctx, arg, expr_meta)? {
+ TypeInner::Scalar { .. } => components.push(arg),
+ TypeInner::Matrix { rows, columns, .. } => {
+ components.reserve(rows as usize * columns as usize);
+ for c in 0..(columns as u32) {
+ let base = ctx.add_expression(
+ Expression::AccessIndex {
+ base: arg,
+ index: c,
+ },
+ expr_meta,
+ body,
+ );
+ for r in 0..(rows as u32) {
+ components.push(ctx.add_expression(
+ Expression::AccessIndex { base, index: r },
+ expr_meta,
+ body,
+ ))
+ }
+ }
+ }
+ TypeInner::Vector { size: ori_size, .. } => {
+ components.reserve(ori_size as usize);
+ for index in 0..(ori_size as u32) {
+ components.push(ctx.add_expression(
+ Expression::AccessIndex { base: arg, index },
+ expr_meta,
+ body,
+ ))
+ }
+ }
+ _ => components.push(arg),
+ }
+ }
+
+ components.truncate(size as usize);
+
+ Ok(ctx.add_expression(Expression::Compose { ty, components }, meta, body))
+ }
+
+ fn constructor_many(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ ty: Handle<Type>,
+ args: Vec<(Handle<Expression>, Span)>,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let mut components = Vec::with_capacity(args.len());
+
+ match self.module.types[ty].inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let mut flattened = Vec::with_capacity(columns as usize * rows as usize);
+
+ for (mut arg, meta) in args.iter().copied() {
+ ctx.forced_conversion(self, &mut arg, meta, ScalarKind::Float, width)?;
+
+ match *self.resolve_type(ctx, arg, meta)? {
+ TypeInner::Vector { size, .. } => {
+ for i in 0..(size as u32) {
+ flattened.push(ctx.add_expression(
+ Expression::AccessIndex {
+ base: arg,
+ index: i,
+ },
+ meta,
+ body,
+ ))
+ }
+ }
+ _ => flattened.push(arg),
+ }
+ }
+
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ kind: ScalarKind::Float,
+ width,
+ },
+ },
+ meta,
+ );
+
+ for chunk in flattened.chunks(rows as usize) {
+ components.push(ctx.add_expression(
+ Expression::Compose {
+ ty,
+ components: Vec::from(chunk),
+ },
+ meta,
+ body,
+ ))
+ }
+ }
+ TypeInner::Vector { size, kind, width } => {
+ return self.vector_constructor(ctx, body, ty, size, kind, width, &args, meta)
+ }
+ TypeInner::Array { base, .. } => {
+ for (mut arg, meta) in args.iter().copied() {
+ let scalar_components = scalar_components(&self.module.types[base].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(self, &mut arg, meta, kind, width)?;
+ }
+
+ components.push(arg)
+ }
+ }
+ TypeInner::Struct { ref members, .. } => {
+ for ((mut arg, meta), member) in args.iter().copied().zip(members.iter()) {
+ let scalar_components = scalar_components(&self.module.types[member.ty].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(self, &mut arg, meta, kind, width)?;
+ }
+
+ components.push(arg)
+ }
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError("Constructor: Too many arguments".into()),
+ meta,
+ })
+ }
+ }
+
+ Ok(ctx.add_expression(Expression::Compose { ty, components }, meta, body))
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn function_call(
+ &mut self,
+ ctx: &mut Context,
+ stmt: &StmtContext,
+ body: &mut Block,
+ name: String,
+ args: Vec<(Handle<Expression>, Span)>,
+ raw_args: &[Handle<HirExpr>],
+ meta: Span,
+ ) -> Result<Option<Handle<Expression>>> {
+ // Grow the typifier to be able to index it later without needing
+ // to hold the context mutably
+ for &(expr, span) in args.iter() {
+ self.typifier_grow(ctx, expr, span)?;
+ }
+
+ // Check if the passed arguments require any special variations
+ let mut variations = builtin_required_variations(
+ args.iter()
+ .map(|&(expr, _)| ctx.typifier.get(expr, &self.module.types)),
+ );
+
+ // Initiate the declaration if it wasn't previously initialized and inject builtins
+ let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| {
+ variations |= BuiltinVariations::STANDARD;
+ Default::default()
+ });
+ inject_builtin(declaration, &mut self.module, &name, variations);
+
+ // Borrow again but without mutability, at this point a declaration is guaranteed
+ let declaration = self.lookup_function.get(&name).unwrap();
+
+ // Possibly contains the overload to be used in the call
+ let mut maybe_overload = None;
+ // The conversions needed for the best analyzed overload, this is initialized all to
+ // `NONE` to make sure that conversions always pass the first time without ambiguity
+ let mut old_conversions = vec![Conversion::None; args.len()];
+ // Tracks whether the comparison between overloads lead to an ambiguity
+ let mut ambiguous = false;
+
+ // Iterate over all the available overloads to select either an exact match or a
+ // overload which has suitable implicit conversions
+ 'outer: for (overload_idx, overload) in declaration.overloads.iter().enumerate() {
+ // If the overload and the function call don't have the same number of arguments
+ // continue to the next overload
+ if args.len() != overload.parameters.len() {
+ continue;
+ }
+
+ log::trace!("Testing overload {}", overload_idx);
+
+ // Stores whether the current overload matches exactly the function call
+ let mut exact = true;
+ // State of the selection
+ // If None we still don't know what is the best overload
+ // If Some(true) the new overload is better
+ // If Some(false) the old overload is better
+ let mut superior = None;
+ // Store the conversions for the current overload so that later they can replace the
+ // conversions used for querying the best overload
+ let mut new_conversions = vec![Conversion::None; args.len()];
+
+ // Loop through the overload parameters and check if the current overload is better
+ // compared to the previous best overload.
+ for (i, overload_parameter) in overload.parameters.iter().enumerate() {
+ let call_argument = &args[i];
+ let parameter_info = &overload.parameters_info[i];
+
+ // If the image is used in the overload as a depth texture convert it
+ // before comparing, otherwise exact matches wouldn't be reported
+ if parameter_info.depth {
+ sampled_to_depth(
+ &mut self.module,
+ ctx,
+ call_argument.0,
+ call_argument.1,
+ &mut self.errors,
+ );
+ self.invalidate_expression(ctx, call_argument.0, call_argument.1)?
+ }
+
+ let overload_param_ty = &self.module.types[*overload_parameter].inner;
+ let call_arg_ty = self.resolve_type(ctx, call_argument.0, call_argument.1)?;
+
+ log::trace!(
+ "Testing parameter {}\n\tOverload = {:?}\n\tCall = {:?}",
+ i,
+ overload_param_ty,
+ call_arg_ty
+ );
+
+ // Storage images cannot be directly compared since while the access is part of the
+ // type in naga's IR, in glsl they are a qualifier and don't enter in the match as
+ // long as the access needed is satisfied.
+ if let (
+ &TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format: overload_format,
+ access: overload_access,
+ },
+ dim: overload_dim,
+ arrayed: overload_arrayed,
+ },
+ &TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format: call_format,
+ access: call_access,
+ },
+ dim: call_dim,
+ arrayed: call_arrayed,
+ },
+ ) = (overload_param_ty, call_arg_ty)
+ {
+ // Images size must match otherwise the overload isn't what we want
+ let good_size = call_dim == overload_dim && call_arrayed == overload_arrayed;
+ // Glsl requires the formats to strictly match unless you are builtin
+ // function overload and have not been replaced, in which case we only
+ // check that the format scalar kind matches
+ let good_format = overload_format == call_format
+ || (overload.internal
+ && ScalarKind::from(overload_format) == ScalarKind::from(call_format));
+ if !(good_size && good_format) {
+ continue 'outer;
+ }
+
+ // While storage access mismatch is an error it isn't one that causes
+ // the overload matching to fail so we defer the error and consider
+ // that the images match exactly
+ if !call_access.contains(overload_access) {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "'{name}': image needs {overload_access:?} access but only {call_access:?} was provided"
+ )
+ .into(),
+ ),
+ meta,
+ });
+ }
+
+ // The images satisfy the conditions to be considered as an exact match
+ new_conversions[i] = Conversion::Exact;
+ continue;
+ } else if overload_param_ty == call_arg_ty {
+ // If the types match there's no need to check for conversions so continue
+ new_conversions[i] = Conversion::Exact;
+ continue;
+ }
+
+ // Glsl defines that inout follows both the conversions for input parameters and
+ // output parameters, this means that the type must have a conversion from both the
+ // call argument to the function parameter and the function parameter to the call
+ // argument, the only way this is possible is for the conversion to be an identity
+ // (i.e. call argument = function parameter)
+ if let ParameterQualifier::InOut = parameter_info.qualifier {
+ continue 'outer;
+ }
+
+ // The function call argument and the function definition
+ // parameter are not equal at this point, so we need to try
+ // implicit conversions.
+ //
+ // Now there are two cases, the argument is defined as a normal
+ // parameter (`in` or `const`), in this case an implicit
+ // conversion is made from the calling argument to the
+ // definition argument. If the parameter is `out` the
+ // opposite needs to be done, so the implicit conversion is made
+ // from the definition argument to the calling argument.
+ let maybe_conversion = if parameter_info.qualifier.is_lhs() {
+ conversion(call_arg_ty, overload_param_ty)
+ } else {
+ conversion(overload_param_ty, call_arg_ty)
+ };
+
+ let conversion = match maybe_conversion {
+ Some(info) => info,
+ None => continue 'outer,
+ };
+
+ // At this point a conversion will be needed so the overload no longer
+ // exactly matches the call arguments
+ exact = false;
+
+ // Compare the conversions needed for this overload parameter to that of the
+ // last overload analyzed respective parameter, the value is:
+ // - `true` when the new overload argument has a better conversion
+ // - `false` when the old overload argument has a better conversion
+ let best_arg = match (conversion, old_conversions[i]) {
+ // An exact match is always better, we don't need to check this for the
+ // current overload since it was checked earlier
+ (_, Conversion::Exact) => false,
+ // No overload was yet analyzed so this one is the best yet
+ (_, Conversion::None) => true,
+ // A conversion from a float to a double is the best possible conversion
+ (Conversion::FloatToDouble, _) => true,
+ (_, Conversion::FloatToDouble) => false,
+ // A conversion from a float to an integer is preferred than one
+ // from double to an integer
+ (Conversion::IntToFloat, Conversion::IntToDouble) => true,
+ (Conversion::IntToDouble, Conversion::IntToFloat) => false,
+ // This case handles things like no conversion and exact which were already
+ // treated and other cases which no conversion is better than the other
+ _ => continue,
+ };
+
+ // Check if the best parameter corresponds to the current selected overload
+ // to pass to the next comparison, if this isn't true mark it as ambiguous
+ match best_arg {
+ true => match superior {
+ Some(false) => ambiguous = true,
+ _ => {
+ superior = Some(true);
+ new_conversions[i] = conversion
+ }
+ },
+ false => match superior {
+ Some(true) => ambiguous = true,
+ _ => superior = Some(false),
+ },
+ }
+ }
+
+ // The overload matches exactly the function call so there's no ambiguity (since
+ // repeated overload aren't allowed) and the current overload is selected, no
+ // further querying is needed.
+ if exact {
+ maybe_overload = Some(overload);
+ ambiguous = false;
+ break;
+ }
+
+ match superior {
+ // New overload is better keep it
+ Some(true) => {
+ maybe_overload = Some(overload);
+ // Replace the conversions
+ old_conversions = new_conversions;
+ }
+ // Old overload is better do nothing
+ Some(false) => {}
+ // No overload was better than the other this can be caused
+ // when all conversions are ambiguous in which the overloads themselves are
+ // ambiguous.
+ None => {
+ ambiguous = true;
+ // Assign the new overload, this helps ensures that in this case of
+ // ambiguity the parsing won't end immediately and allow for further
+ // collection of errors.
+ maybe_overload = Some(overload);
+ }
+ }
+ }
+
+ if ambiguous {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Ambiguous best function for '{name}'").into(),
+ ),
+ meta,
+ })
+ }
+
+ let overload = maybe_overload.ok_or_else(|| Error {
+ kind: ErrorKind::SemanticError(format!("Unknown function '{name}'").into()),
+ meta,
+ })?;
+
+ let parameters_info = overload.parameters_info.clone();
+ let parameters = overload.parameters.clone();
+ let is_void = overload.void;
+ let kind = overload.kind;
+
+ let mut arguments = Vec::with_capacity(args.len());
+ let mut proxy_writes = Vec::new();
+
+ // Iterate through the function call arguments applying transformations as needed
+ for (((parameter_info, call_argument), expr), parameter) in parameters_info
+ .iter()
+ .zip(&args)
+ .zip(raw_args)
+ .zip(&parameters)
+ {
+ let (mut handle, meta) =
+ ctx.lower_expect_inner(stmt, self, *expr, parameter_info.qualifier.as_pos(), body)?;
+
+ if parameter_info.qualifier.is_lhs() {
+ self.process_lhs_argument(
+ ctx,
+ body,
+ meta,
+ *parameter,
+ parameter_info,
+ handle,
+ call_argument,
+ &mut proxy_writes,
+ &mut arguments,
+ )?;
+
+ continue;
+ }
+
+ let scalar_comps = scalar_components(&self.module.types[*parameter].inner);
+
+ // Apply implicit conversions as needed
+ if let Some((kind, width)) = scalar_comps {
+ ctx.implicit_conversion(self, &mut handle, meta, kind, width)?;
+ }
+
+ arguments.push(handle)
+ }
+
+ match kind {
+ FunctionKind::Call(function) => {
+ ctx.emit_end(body);
+
+ let result = if !is_void {
+ Some(ctx.add_expression(Expression::CallResult(function), meta, body))
+ } else {
+ None
+ };
+
+ body.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ meta,
+ );
+
+ ctx.emit_start();
+
+ // Write back all the variables that were scheduled to their original place
+ for proxy_write in proxy_writes {
+ let mut value = ctx.add_expression(
+ Expression::Load {
+ pointer: proxy_write.value,
+ },
+ meta,
+ body,
+ );
+
+ if let Some((kind, width)) = proxy_write.convert {
+ ctx.conversion(&mut value, meta, kind, width)?;
+ }
+
+ ctx.emit_restart(body);
+
+ body.push(
+ Statement::Store {
+ pointer: proxy_write.target,
+ value,
+ },
+ meta,
+ );
+ }
+
+ Ok(result)
+ }
+ FunctionKind::Macro(builtin) => {
+ builtin.call(self, ctx, body, arguments.as_mut_slice(), meta)
+ }
+ }
+ }
+
+ /// Processes a function call argument that appears in place of an output
+ /// parameter.
+ #[allow(clippy::too_many_arguments)]
+ fn process_lhs_argument(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ meta: Span,
+ parameter_ty: Handle<Type>,
+ parameter_info: &ParameterInfo,
+ original: Handle<Expression>,
+ call_argument: &(Handle<Expression>, Span),
+ proxy_writes: &mut Vec<ProxyWrite>,
+ arguments: &mut Vec<Handle<Expression>>,
+ ) -> Result<()> {
+ let original_ty = self.resolve_type(ctx, original, meta)?;
+ let original_pointer_space = original_ty.pointer_space();
+
+ // The type of a possible spill variable needed for a proxy write
+ let mut maybe_ty = match *original_ty {
+ // If the argument is to be passed as a pointer but the type of the
+ // expression returns a vector it must mean that it was for example
+ // swizzled and it must be spilled into a local before calling
+ TypeInner::Vector { size, kind, width } => Some(self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector { size, kind, width },
+ },
+ Span::default(),
+ )),
+ // If the argument is a pointer whose address space isn't `Function`, an
+ // indirection through a local variable is needed to align the address
+ // spaces of the call argument and the overload parameter.
+ TypeInner::Pointer { base, space } if space != AddressSpace::Function => Some(base),
+ TypeInner::ValuePointer {
+ size,
+ kind,
+ width,
+ space,
+ } if space != AddressSpace::Function => {
+ let inner = match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+
+ Some(
+ self.module
+ .types
+ .insert(Type { name: None, inner }, Span::default()),
+ )
+ }
+ _ => None,
+ };
+
+ // Since the original expression might be a pointer and we want a value
+ // for the proxy writes, we might need to load the pointer.
+ let value = if original_pointer_space.is_some() {
+ ctx.add_expression(
+ Expression::Load { pointer: original },
+ Span::default(),
+ body,
+ )
+ } else {
+ original
+ };
+
+ let call_arg_ty = self.resolve_type(ctx, call_argument.0, call_argument.1)?;
+ let overload_param_ty = &self.module.types[parameter_ty].inner;
+ let needs_conversion = call_arg_ty != overload_param_ty;
+
+ let arg_scalar_comps = scalar_components(call_arg_ty);
+
+ // Since output parameters also allow implicit conversions from the
+ // parameter to the argument, we need to spill the conversion to a
+ // variable and create a proxy write for the original variable.
+ if needs_conversion {
+ maybe_ty = Some(parameter_ty);
+ }
+
+ if let Some(ty) = maybe_ty {
+ // Create the spill variable
+ let spill_var = ctx.locals.append(
+ LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ Span::default(),
+ );
+ let spill_expr =
+ ctx.add_expression(Expression::LocalVariable(spill_var), Span::default(), body);
+
+ // If the argument is also copied in we must store the value of the
+ // original variable to the spill variable.
+ if let ParameterQualifier::InOut = parameter_info.qualifier {
+ body.push(
+ Statement::Store {
+ pointer: spill_expr,
+ value,
+ },
+ Span::default(),
+ );
+ }
+
+ // Add the spill variable as an argument to the function call
+ arguments.push(spill_expr);
+
+ let convert = if needs_conversion {
+ arg_scalar_comps
+ } else {
+ None
+ };
+
+ // Register the temporary local to be written back to it's original
+ // place after the function call
+ if let Expression::Swizzle {
+ size,
+ mut vector,
+ pattern,
+ } = ctx.expressions[original]
+ {
+ if let Expression::Load { pointer } = ctx.expressions[vector] {
+ vector = pointer;
+ }
+
+ for (i, component) in pattern.iter().take(size as usize).enumerate() {
+ let original = ctx.add_expression(
+ Expression::AccessIndex {
+ base: vector,
+ index: *component as u32,
+ },
+ Span::default(),
+ body,
+ );
+
+ let spill_component = ctx.add_expression(
+ Expression::AccessIndex {
+ base: spill_expr,
+ index: i as u32,
+ },
+ Span::default(),
+ body,
+ );
+
+ proxy_writes.push(ProxyWrite {
+ target: original,
+ value: spill_component,
+ convert,
+ });
+ }
+ } else {
+ proxy_writes.push(ProxyWrite {
+ target: original,
+ value: spill_expr,
+ convert,
+ });
+ }
+ } else {
+ arguments.push(original);
+ }
+
+ Ok(())
+ }
+
+ pub(crate) fn add_function(
+ &mut self,
+ ctx: Context,
+ name: String,
+ result: Option<FunctionResult>,
+ mut body: Block,
+ meta: Span,
+ ) {
+ ensure_block_returns(&mut body);
+
+ let void = result.is_none();
+
+ let &mut Frontend {
+ ref mut lookup_function,
+ ref mut module,
+ ..
+ } = self;
+
+ // Check if the passed arguments require any special variations
+ let mut variations =
+ builtin_required_variations(ctx.parameters.iter().map(|&arg| &module.types[arg].inner));
+
+ // Initiate the declaration if it wasn't previously initialized and inject builtins
+ let declaration = lookup_function.entry(name.clone()).or_insert_with(|| {
+ variations |= BuiltinVariations::STANDARD;
+ Default::default()
+ });
+ inject_builtin(declaration, module, &name, variations);
+
+ let Context {
+ expressions,
+ locals,
+ arguments,
+ parameters,
+ parameters_info,
+ ..
+ } = ctx;
+
+ let function = Function {
+ name: Some(name),
+ arguments,
+ result,
+ local_variables: locals,
+ expressions,
+ named_expressions: crate::NamedExpressions::default(),
+ body,
+ };
+
+ 'outer: for decl in declaration.overloads.iter_mut() {
+ if parameters.len() != decl.parameters.len() {
+ continue;
+ }
+
+ for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) {
+ let new_inner = &module.types[*new_parameter].inner;
+ let old_inner = &module.types[*old_parameter].inner;
+
+ if new_inner != old_inner {
+ continue 'outer;
+ }
+ }
+
+ if decl.defined {
+ return self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Function already defined".into()),
+ meta,
+ });
+ }
+
+ decl.defined = true;
+ decl.parameters_info = parameters_info;
+ match decl.kind {
+ FunctionKind::Call(handle) => *self.module.functions.get_mut(handle) = function,
+ FunctionKind::Macro(_) => {
+ let handle = module.functions.append(function, meta);
+ decl.kind = FunctionKind::Call(handle)
+ }
+ }
+ return;
+ }
+
+ let handle = module.functions.append(function, meta);
+ declaration.overloads.push(Overload {
+ parameters,
+ parameters_info,
+ kind: FunctionKind::Call(handle),
+ defined: true,
+ internal: false,
+ void,
+ });
+ }
+
+ pub(crate) fn add_prototype(
+ &mut self,
+ ctx: Context,
+ name: String,
+ result: Option<FunctionResult>,
+ meta: Span,
+ ) {
+ let void = result.is_none();
+
+ let &mut Frontend {
+ ref mut lookup_function,
+ ref mut module,
+ ..
+ } = self;
+
+ // Check if the passed arguments require any special variations
+ let mut variations =
+ builtin_required_variations(ctx.parameters.iter().map(|&arg| &module.types[arg].inner));
+
+ // Initiate the declaration if it wasn't previously initialized and inject builtins
+ let declaration = lookup_function.entry(name.clone()).or_insert_with(|| {
+ variations |= BuiltinVariations::STANDARD;
+ Default::default()
+ });
+ inject_builtin(declaration, module, &name, variations);
+
+ let Context {
+ arguments,
+ parameters,
+ parameters_info,
+ ..
+ } = ctx;
+
+ let function = Function {
+ name: Some(name),
+ arguments,
+ result,
+ ..Default::default()
+ };
+
+ 'outer: for decl in declaration.overloads.iter() {
+ if parameters.len() != decl.parameters.len() {
+ continue;
+ }
+
+ for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) {
+ let new_inner = &module.types[*new_parameter].inner;
+ let old_inner = &module.types[*old_parameter].inner;
+
+ if new_inner != old_inner {
+ continue 'outer;
+ }
+ }
+
+ return self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Prototype already defined".into()),
+ meta,
+ });
+ }
+
+ let handle = module.functions.append(function, meta);
+ declaration.overloads.push(Overload {
+ parameters,
+ parameters_info,
+ kind: FunctionKind::Call(handle),
+ defined: false,
+ internal: false,
+ void,
+ });
+ }
+
+ /// Helper function for building the input/output interface of the entry point
+ ///
+ /// Calls `f` with the data of the entry point argument, flattening composite types
+ /// recursively
+ ///
+ /// The passed arguments to the callback are:
+ /// - The name
+ /// - The pointer expression to the global storage
+ /// - The handle to the type of the entry point argument
+ /// - The binding of the entry point argument
+ /// - The expression arena
+ fn arg_type_walker(
+ &self,
+ name: Option<String>,
+ binding: crate::Binding,
+ pointer: Handle<Expression>,
+ ty: Handle<Type>,
+ expressions: &mut Arena<Expression>,
+ f: &mut impl FnMut(
+ Option<String>,
+ Handle<Expression>,
+ Handle<Type>,
+ crate::Binding,
+ &mut Arena<Expression>,
+ ),
+ ) {
+ match self.module.types[ty].inner {
+ TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(constant),
+ ..
+ } => {
+ let mut location = match binding {
+ crate::Binding::Location { location, .. } => location,
+ crate::Binding::BuiltIn(_) => return,
+ };
+
+ // TODO: Better error reporting
+ // right now we just don't walk the array if the size isn't known at
+ // compile time and let validation catch it
+ let size = match self.module.constants[constant].to_array_length() {
+ Some(val) => val,
+ None => return f(name, pointer, ty, binding, expressions),
+ };
+
+ let interpolation =
+ self.module.types[base]
+ .inner
+ .scalar_kind()
+ .map(|kind| match kind {
+ ScalarKind::Float => crate::Interpolation::Perspective,
+ _ => crate::Interpolation::Flat,
+ });
+
+ for index in 0..size {
+ let member_pointer = expressions.append(
+ Expression::AccessIndex {
+ base: pointer,
+ index,
+ },
+ crate::Span::default(),
+ );
+
+ let binding = crate::Binding::Location {
+ location,
+ interpolation,
+ sampling: None,
+ };
+ location += 1;
+
+ self.arg_type_walker(
+ name.clone(),
+ binding,
+ member_pointer,
+ base,
+ expressions,
+ f,
+ )
+ }
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let mut location = match binding {
+ crate::Binding::Location { location, .. } => location,
+ crate::Binding::BuiltIn(_) => return,
+ };
+
+ for (i, member) in members.iter().enumerate() {
+ let member_pointer = expressions.append(
+ Expression::AccessIndex {
+ base: pointer,
+ index: i as u32,
+ },
+ crate::Span::default(),
+ );
+
+ let binding = match member.binding.clone() {
+ Some(binding) => binding,
+ None => {
+ let interpolation = self.module.types[member.ty]
+ .inner
+ .scalar_kind()
+ .map(|kind| match kind {
+ ScalarKind::Float => crate::Interpolation::Perspective,
+ _ => crate::Interpolation::Flat,
+ });
+ let binding = crate::Binding::Location {
+ location,
+ interpolation,
+ sampling: None,
+ };
+ location += 1;
+ binding
+ }
+ };
+
+ self.arg_type_walker(
+ member.name.clone(),
+ binding,
+ member_pointer,
+ member.ty,
+ expressions,
+ f,
+ )
+ }
+ }
+ _ => f(name, pointer, ty, binding, expressions),
+ }
+ }
+
+ pub(crate) fn add_entry_point(
+ &mut self,
+ function: Handle<Function>,
+ global_init_body: Block,
+ mut expressions: Arena<Expression>,
+ ) {
+ let mut arguments = Vec::new();
+ let mut body = Block::with_capacity(
+ // global init body
+ global_init_body.len() +
+ // prologue and epilogue
+ self.entry_args.len() * 2
+ // Call, Emit for composing struct and return
+ + 3,
+ );
+
+ for arg in self.entry_args.iter() {
+ if arg.storage != StorageQualifier::Input {
+ continue;
+ }
+
+ let pointer =
+ expressions.append(Expression::GlobalVariable(arg.handle), Default::default());
+
+ self.arg_type_walker(
+ arg.name.clone(),
+ arg.binding.clone(),
+ pointer,
+ self.module.global_variables[arg.handle].ty,
+ &mut expressions,
+ &mut |name, pointer, ty, binding, expressions| {
+ let idx = arguments.len() as u32;
+
+ arguments.push(FunctionArgument {
+ name,
+ ty,
+ binding: Some(binding),
+ });
+
+ let value =
+ expressions.append(Expression::FunctionArgument(idx), Default::default());
+ body.push(Statement::Store { pointer, value }, Default::default());
+ },
+ )
+ }
+
+ body.extend_block(global_init_body);
+
+ body.push(
+ Statement::Call {
+ function,
+ arguments: Vec::new(),
+ result: None,
+ },
+ Default::default(),
+ );
+
+ let mut span = 0;
+ let mut members = Vec::new();
+ let mut components = Vec::new();
+
+ for arg in self.entry_args.iter() {
+ if arg.storage != StorageQualifier::Output {
+ continue;
+ }
+
+ let pointer =
+ expressions.append(Expression::GlobalVariable(arg.handle), Default::default());
+
+ self.arg_type_walker(
+ arg.name.clone(),
+ arg.binding.clone(),
+ pointer,
+ self.module.global_variables[arg.handle].ty,
+ &mut expressions,
+ &mut |name, pointer, ty, binding, expressions| {
+ members.push(StructMember {
+ name,
+ ty,
+ binding: Some(binding),
+ offset: span,
+ });
+
+ span += self.module.types[ty].inner.size(&self.module.constants);
+
+ let len = expressions.len();
+ let load = expressions.append(Expression::Load { pointer }, Default::default());
+ body.push(
+ Statement::Emit(expressions.range_from(len)),
+ Default::default(),
+ );
+ components.push(load)
+ },
+ )
+ }
+
+ let (ty, value) = if !components.is_empty() {
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Struct { members, span },
+ },
+ Default::default(),
+ );
+
+ let len = expressions.len();
+ let res =
+ expressions.append(Expression::Compose { ty, components }, Default::default());
+ body.push(
+ Statement::Emit(expressions.range_from(len)),
+ Default::default(),
+ );
+
+ (Some(ty), Some(res))
+ } else {
+ (None, None)
+ };
+
+ body.push(Statement::Return { value }, Default::default());
+
+ self.module.entry_points.push(EntryPoint {
+ name: "main".to_string(),
+ stage: self.meta.stage,
+ early_depth_test: Some(crate::EarlyDepthTest { conservative: None })
+ .filter(|_| self.meta.early_fragment_tests),
+ workgroup_size: self.meta.workgroup_size,
+ function: Function {
+ arguments,
+ expressions,
+ body,
+ result: ty.map(|ty| FunctionResult { ty, binding: None }),
+ ..Default::default()
+ },
+ });
+ }
+}
+
+/// Helper enum containing the type of conversion need for a call
+#[derive(PartialEq, Eq, Clone, Copy, Debug)]
+enum Conversion {
+ /// No conversion needed
+ Exact,
+ /// Float to double conversion needed
+ FloatToDouble,
+ /// Int or uint to float conversion needed
+ IntToFloat,
+ /// Int or uint to double conversion needed
+ IntToDouble,
+ /// Other type of conversion needed
+ Other,
+ /// No conversion was yet registered
+ None,
+}
+
+/// Helper function, returns the type of conversion from `source` to `target`, if a
+/// conversion is not possible returns None.
+fn conversion(target: &TypeInner, source: &TypeInner) -> Option<Conversion> {
+ use ScalarKind::*;
+
+ // Gather the `ScalarKind` and scalar width from both the target and the source
+ let (target_kind, target_width, source_kind, source_width) = match (target, source) {
+ // Conversions between scalars are allowed
+ (
+ &TypeInner::Scalar {
+ kind: tgt_kind,
+ width: tgt_width,
+ },
+ &TypeInner::Scalar {
+ kind: src_kind,
+ width: src_width,
+ },
+ ) => (tgt_kind, tgt_width, src_kind, src_width),
+ // Conversions between vectors of the same size are allowed
+ (
+ &TypeInner::Vector {
+ kind: tgt_kind,
+ size: tgt_size,
+ width: tgt_width,
+ },
+ &TypeInner::Vector {
+ kind: src_kind,
+ size: src_size,
+ width: src_width,
+ },
+ ) if tgt_size == src_size => (tgt_kind, tgt_width, src_kind, src_width),
+ // Conversions between matrices of the same size are allowed
+ (
+ &TypeInner::Matrix {
+ rows: tgt_rows,
+ columns: tgt_cols,
+ width: tgt_width,
+ },
+ &TypeInner::Matrix {
+ rows: src_rows,
+ columns: src_cols,
+ width: src_width,
+ },
+ ) if tgt_cols == src_cols && tgt_rows == src_rows => (Float, tgt_width, Float, src_width),
+ _ => return None,
+ };
+
+ // Check if source can be converted into target, if this is the case then the type
+ // power of target must be higher than that of source
+ let target_power = type_power(target_kind, target_width);
+ let source_power = type_power(source_kind, source_width);
+ if target_power < source_power {
+ return None;
+ }
+
+ Some(
+ match ((target_kind, target_width), (source_kind, source_width)) {
+ // A conversion from a float to a double is special
+ ((Float, 8), (Float, 4)) => Conversion::FloatToDouble,
+ // A conversion from an integer to a float is special
+ ((Float, 4), (Sint | Uint, _)) => Conversion::IntToFloat,
+ // A conversion from an integer to a double is special
+ ((Float, 8), (Sint | Uint, _)) => Conversion::IntToDouble,
+ _ => Conversion::Other,
+ },
+ )
+}
+
+/// Helper method returning all the non standard builtin variations needed
+/// to process the function call with the passed arguments
+fn builtin_required_variations<'a>(args: impl Iterator<Item = &'a TypeInner>) -> BuiltinVariations {
+ let mut variations = BuiltinVariations::empty();
+
+ for ty in args {
+ match *ty {
+ TypeInner::ValuePointer { kind, width, .. }
+ | TypeInner::Scalar { kind, width }
+ | TypeInner::Vector { kind, width, .. } => {
+ if kind == ScalarKind::Float && width == 8 {
+ variations |= BuiltinVariations::DOUBLE
+ }
+ }
+ TypeInner::Matrix { width, .. } => {
+ if width == 8 {
+ variations |= BuiltinVariations::DOUBLE
+ }
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ if dim == crate::ImageDimension::Cube && arrayed {
+ variations |= BuiltinVariations::CUBE_TEXTURES_ARRAY
+ }
+
+ if dim == crate::ImageDimension::D2 && arrayed && class.is_multisampled() {
+ variations |= BuiltinVariations::D2_MULTI_TEXTURES_ARRAY
+ }
+ }
+ _ => {}
+ }
+ }
+
+ variations
+}
diff --git a/third_party/rust/naga/src/front/glsl/lex.rs b/third_party/rust/naga/src/front/glsl/lex.rs
new file mode 100644
index 0000000000..1b59a9bf3e
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/lex.rs
@@ -0,0 +1,301 @@
+use super::{
+ ast::Precision,
+ token::{Directive, DirectiveKind, Token, TokenValue},
+ types::parse_type,
+};
+use crate::{FastHashMap, Span, StorageAccess};
+use pp_rs::{
+ pp::Preprocessor,
+ token::{PreprocessorError, Punct, TokenValue as PPTokenValue},
+};
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct LexerResult {
+ pub kind: LexerResultKind,
+ pub meta: Span,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum LexerResultKind {
+ Token(Token),
+ Directive(Directive),
+ Error(PreprocessorError),
+}
+
+pub struct Lexer<'a> {
+ pp: Preprocessor<'a>,
+}
+
+impl<'a> Lexer<'a> {
+ pub fn new(input: &'a str, defines: &'a FastHashMap<String, String>) -> Self {
+ let mut pp = Preprocessor::new(input);
+ for (define, value) in defines {
+ pp.add_define(define, value).unwrap(); //TODO: handle error
+ }
+ Lexer { pp }
+ }
+}
+
+impl<'a> Iterator for Lexer<'a> {
+ type Item = LexerResult;
+ fn next(&mut self) -> Option<Self::Item> {
+ let pp_token = match self.pp.next()? {
+ Ok(t) => t,
+ Err((err, loc)) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Error(err),
+ meta: loc.into(),
+ });
+ }
+ };
+
+ let meta = pp_token.location.into();
+ let value = match pp_token.value {
+ PPTokenValue::Extension(extension) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Extension,
+ tokens: extension.tokens,
+ }),
+ meta,
+ })
+ }
+ PPTokenValue::Float(float) => TokenValue::FloatConstant(float),
+ PPTokenValue::Ident(ident) => {
+ match ident.as_str() {
+ // Qualifiers
+ "layout" => TokenValue::Layout,
+ "in" => TokenValue::In,
+ "out" => TokenValue::Out,
+ "uniform" => TokenValue::Uniform,
+ "buffer" => TokenValue::Buffer,
+ "shared" => TokenValue::Shared,
+ "invariant" => TokenValue::Invariant,
+ "flat" => TokenValue::Interpolation(crate::Interpolation::Flat),
+ "noperspective" => TokenValue::Interpolation(crate::Interpolation::Linear),
+ "smooth" => TokenValue::Interpolation(crate::Interpolation::Perspective),
+ "centroid" => TokenValue::Sampling(crate::Sampling::Centroid),
+ "sample" => TokenValue::Sampling(crate::Sampling::Sample),
+ "const" => TokenValue::Const,
+ "inout" => TokenValue::InOut,
+ "precision" => TokenValue::Precision,
+ "highp" => TokenValue::PrecisionQualifier(Precision::High),
+ "mediump" => TokenValue::PrecisionQualifier(Precision::Medium),
+ "lowp" => TokenValue::PrecisionQualifier(Precision::Low),
+ "restrict" => TokenValue::Restrict,
+ "readonly" => TokenValue::MemoryQualifier(StorageAccess::LOAD),
+ "writeonly" => TokenValue::MemoryQualifier(StorageAccess::STORE),
+ // values
+ "true" => TokenValue::BoolConstant(true),
+ "false" => TokenValue::BoolConstant(false),
+ // jump statements
+ "continue" => TokenValue::Continue,
+ "break" => TokenValue::Break,
+ "return" => TokenValue::Return,
+ "discard" => TokenValue::Discard,
+ // selection statements
+ "if" => TokenValue::If,
+ "else" => TokenValue::Else,
+ "switch" => TokenValue::Switch,
+ "case" => TokenValue::Case,
+ "default" => TokenValue::Default,
+ // iteration statements
+ "while" => TokenValue::While,
+ "do" => TokenValue::Do,
+ "for" => TokenValue::For,
+ // types
+ "void" => TokenValue::Void,
+ "struct" => TokenValue::Struct,
+ word => match parse_type(word) {
+ Some(t) => TokenValue::TypeName(t),
+ None => TokenValue::Identifier(String::from(word)),
+ },
+ }
+ }
+ PPTokenValue::Integer(integer) => TokenValue::IntConstant(integer),
+ PPTokenValue::Punct(punct) => match punct {
+ // Compound assignments
+ Punct::AddAssign => TokenValue::AddAssign,
+ Punct::SubAssign => TokenValue::SubAssign,
+ Punct::MulAssign => TokenValue::MulAssign,
+ Punct::DivAssign => TokenValue::DivAssign,
+ Punct::ModAssign => TokenValue::ModAssign,
+ Punct::LeftShiftAssign => TokenValue::LeftShiftAssign,
+ Punct::RightShiftAssign => TokenValue::RightShiftAssign,
+ Punct::AndAssign => TokenValue::AndAssign,
+ Punct::XorAssign => TokenValue::XorAssign,
+ Punct::OrAssign => TokenValue::OrAssign,
+
+ // Two character punctuation
+ Punct::Increment => TokenValue::Increment,
+ Punct::Decrement => TokenValue::Decrement,
+ Punct::LogicalAnd => TokenValue::LogicalAnd,
+ Punct::LogicalOr => TokenValue::LogicalOr,
+ Punct::LogicalXor => TokenValue::LogicalXor,
+ Punct::LessEqual => TokenValue::LessEqual,
+ Punct::GreaterEqual => TokenValue::GreaterEqual,
+ Punct::EqualEqual => TokenValue::Equal,
+ Punct::NotEqual => TokenValue::NotEqual,
+ Punct::LeftShift => TokenValue::LeftShift,
+ Punct::RightShift => TokenValue::RightShift,
+
+ // Parenthesis or similar
+ Punct::LeftBrace => TokenValue::LeftBrace,
+ Punct::RightBrace => TokenValue::RightBrace,
+ Punct::LeftParen => TokenValue::LeftParen,
+ Punct::RightParen => TokenValue::RightParen,
+ Punct::LeftBracket => TokenValue::LeftBracket,
+ Punct::RightBracket => TokenValue::RightBracket,
+
+ // Other one character punctuation
+ Punct::LeftAngle => TokenValue::LeftAngle,
+ Punct::RightAngle => TokenValue::RightAngle,
+ Punct::Semicolon => TokenValue::Semicolon,
+ Punct::Comma => TokenValue::Comma,
+ Punct::Colon => TokenValue::Colon,
+ Punct::Dot => TokenValue::Dot,
+ Punct::Equal => TokenValue::Assign,
+ Punct::Bang => TokenValue::Bang,
+ Punct::Minus => TokenValue::Dash,
+ Punct::Tilde => TokenValue::Tilde,
+ Punct::Plus => TokenValue::Plus,
+ Punct::Star => TokenValue::Star,
+ Punct::Slash => TokenValue::Slash,
+ Punct::Percent => TokenValue::Percent,
+ Punct::Pipe => TokenValue::VerticalBar,
+ Punct::Caret => TokenValue::Caret,
+ Punct::Ampersand => TokenValue::Ampersand,
+ Punct::Question => TokenValue::Question,
+ },
+ PPTokenValue::Pragma(pragma) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Pragma,
+ tokens: pragma.tokens,
+ }),
+ meta,
+ })
+ }
+ PPTokenValue::Version(version) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Version {
+ is_first_directive: version.is_first_directive,
+ },
+ tokens: version.tokens,
+ }),
+ meta,
+ })
+ }
+ };
+
+ Some(LexerResult {
+ kind: LexerResultKind::Token(Token { value, meta }),
+ meta,
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use pp_rs::token::{Integer, Location, Token as PPToken, TokenValue as PPTokenValue};
+
+ use super::{
+ super::token::{Directive, DirectiveKind, Token, TokenValue},
+ Lexer, LexerResult, LexerResultKind,
+ };
+ use crate::Span;
+
+ #[test]
+ fn lex_tokens() {
+ let defines = crate::FastHashMap::default();
+
+ // line comments
+ let mut lex = Lexer::new("#version 450\nvoid main () {}", &defines);
+ let mut location = Location::default();
+ location.start = 9;
+ location.end = 12;
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Version {
+ is_first_directive: true
+ },
+ tokens: vec![PPToken {
+ value: PPTokenValue::Integer(Integer {
+ signed: true,
+ value: 450,
+ width: 32
+ }),
+ location
+ }]
+ }),
+ meta: Span::new(1, 8)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::Void,
+ meta: Span::new(13, 17)
+ }),
+ meta: Span::new(13, 17)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::Identifier("main".into()),
+ meta: Span::new(18, 22)
+ }),
+ meta: Span::new(18, 22)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::LeftParen,
+ meta: Span::new(23, 24)
+ }),
+ meta: Span::new(23, 24)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::RightParen,
+ meta: Span::new(24, 25)
+ }),
+ meta: Span::new(24, 25)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::LeftBrace,
+ meta: Span::new(26, 27)
+ }),
+ meta: Span::new(26, 27)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::RightBrace,
+ meta: Span::new(27, 28)
+ }),
+ meta: Span::new(27, 28)
+ }
+ );
+ assert_eq!(lex.next(), None);
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs
new file mode 100644
index 0000000000..a1aa4622d3
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/mod.rs
@@ -0,0 +1,235 @@
+/*!
+Frontend for [GLSL][glsl] (OpenGL Shading Language).
+
+To begin, take a look at the documentation for the [`Frontend`].
+
+# Supported versions
+## Vulkan
+- 440 (partial)
+- 450
+- 460
+
+[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php
+*/
+
+pub use ast::{Precision, Profile};
+pub use error::{Error, ErrorKind, ExpectedToken};
+pub use token::TokenValue;
+
+use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type};
+use ast::{EntryArg, FunctionDeclaration, GlobalLookup};
+use parser::ParsingContext;
+
+mod ast;
+mod builtins;
+mod constants;
+mod context;
+mod error;
+mod functions;
+mod lex;
+mod offset;
+mod parser;
+#[cfg(test)]
+mod parser_tests;
+mod token;
+mod types;
+mod variables;
+
+type Result<T> = std::result::Result<T, Error>;
+
+/// Per-shader options passed to [`parse`](Frontend::parse).
+///
+/// The [`From`](From) trait is implemented for [`ShaderStage`](ShaderStage) to
+/// provide a quick way to create a Options instance.
+/// ```rust
+/// # use naga::ShaderStage;
+/// # use naga::front::glsl::Options;
+/// Options::from(ShaderStage::Vertex);
+/// ```
+#[derive(Debug)]
+pub struct Options {
+ /// The shader stage in the pipeline.
+ pub stage: ShaderStage,
+ /// Preprocesor definitions to be used, akin to having
+ /// ```glsl
+ /// #define key value
+ /// ```
+ /// for each key value pair in the map.
+ pub defines: FastHashMap<String, String>,
+}
+
+impl From<ShaderStage> for Options {
+ fn from(stage: ShaderStage) -> Self {
+ Options {
+ stage,
+ defines: FastHashMap::default(),
+ }
+ }
+}
+
+/// Additional information about the GLSL shader.
+///
+/// Stores additional information about the GLSL shader which might not be
+/// stored in the shader [`Module`](Module).
+#[derive(Debug)]
+pub struct ShaderMetadata {
+ /// The GLSL version specified in the shader through the use of the
+ /// `#version` preprocessor directive.
+ pub version: u16,
+ /// The GLSL profile specified in the shader through the use of the
+ /// `#version` preprocessor directive.
+ pub profile: Profile,
+ /// The shader stage in the pipeline, passed to the [`parse`](Frontend::parse)
+ /// method via the [`Options`](Options) struct.
+ pub stage: ShaderStage,
+
+ /// The workgroup size for compute shaders, defaults to `[1; 3]` for
+ /// compute shaders and `[0; 3]` for non compute shaders.
+ pub workgroup_size: [u32; 3],
+ /// Whether or not early fragment tests where requested by the shader.
+ /// Defaults to `false`.
+ pub early_fragment_tests: bool,
+
+ /// The shader can request extensions via the
+ /// `#extension` preprocessor directive, in the directive a behavior
+ /// parameter is used to control whether the extension should be disabled,
+ /// warn on usage, enabled if possible or required.
+ ///
+ /// This field only stores extensions which were required or requested to
+ /// be enabled if possible and they are supported.
+ pub extensions: FastHashSet<String>,
+}
+
+impl ShaderMetadata {
+ fn reset(&mut self, stage: ShaderStage) {
+ self.version = 0;
+ self.profile = Profile::Core;
+ self.stage = stage;
+ self.workgroup_size = [u32::from(stage == ShaderStage::Compute); 3];
+ self.early_fragment_tests = false;
+ self.extensions.clear();
+ }
+}
+
+impl Default for ShaderMetadata {
+ fn default() -> Self {
+ ShaderMetadata {
+ version: 0,
+ profile: Profile::Core,
+ stage: ShaderStage::Vertex,
+ workgroup_size: [0; 3],
+ early_fragment_tests: false,
+ extensions: FastHashSet::default(),
+ }
+ }
+}
+
+/// The `Frontend` is the central structure of the GLSL frontend.
+///
+/// To instantiate a new `Frontend` the [`Default`](Default) trait is used, so a
+/// call to the associated function [`Frontend::default`](Frontend::default) will
+/// return a new `Frontend` instance.
+///
+/// To parse a shader simply call the [`parse`](Frontend::parse) method with a
+/// [`Options`](Options) struct and a [`&str`](str) holding the glsl code.
+///
+/// The `Frontend` also provides the [`metadata`](Frontend::metadata) to get some
+/// further information about the previously parsed shader, like version and
+/// extensions used (see the documentation for
+/// [`ShaderMetadata`](ShaderMetadata) to see all the returned information)
+///
+/// # Example usage
+/// ```rust
+/// use naga::ShaderStage;
+/// use naga::front::glsl::{Frontend, Options};
+///
+/// let glsl = r#"
+/// #version 450 core
+///
+/// void main() {}
+/// "#;
+///
+/// let mut frontend = Frontend::default();
+/// let options = Options::from(ShaderStage::Vertex);
+/// frontend.parse(&options, glsl);
+/// ```
+///
+/// # Reusability
+///
+/// If there's a need to parse more than one shader reusing the same `Frontend`
+/// instance may be beneficial since internal allocations will be reused.
+///
+/// Calling the [`parse`](Frontend::parse) method multiple times will reset the
+/// `Frontend` so no extra care is needed when reusing.
+#[derive(Debug, Default)]
+pub struct Frontend {
+ meta: ShaderMetadata,
+
+ lookup_function: FastHashMap<String, FunctionDeclaration>,
+ lookup_type: FastHashMap<String, Handle<Type>>,
+
+ global_variables: Vec<(String, GlobalLookup)>,
+
+ entry_args: Vec<EntryArg>,
+
+ layouter: Layouter,
+
+ errors: Vec<Error>,
+
+ module: Module,
+}
+
+impl Frontend {
+ fn reset(&mut self, stage: ShaderStage) {
+ self.meta.reset(stage);
+
+ self.lookup_function.clear();
+ self.lookup_type.clear();
+ self.global_variables.clear();
+ self.entry_args.clear();
+ self.layouter.clear();
+
+ // This is necessary because if the last parsing errored out, the module
+ // wouldn't have been taken
+ self.module = Module::default();
+ }
+
+ /// Parses a shader either outputting a shader [`Module`](Module) or a list
+ /// of [`Error`](Error)s.
+ ///
+ /// Multiple calls using the same `Frontend` and different shaders are supported.
+ pub fn parse(
+ &mut self,
+ options: &Options,
+ source: &str,
+ ) -> std::result::Result<Module, Vec<Error>> {
+ self.reset(options.stage);
+
+ let lexer = lex::Lexer::new(source, &options.defines);
+ let mut ctx = ParsingContext::new(lexer);
+
+ if let Err(e) = ctx.parse(self) {
+ self.errors.push(e);
+ }
+
+ if self.errors.is_empty() {
+ Ok(std::mem::take(&mut self.module))
+ } else {
+ Err(std::mem::take(&mut self.errors))
+ }
+ }
+
+ /// Returns additional information about the parsed shader which might not be
+ /// stored in the [`Module`](Module), see the documentation for
+ /// [`ShaderMetadata`](ShaderMetadata) for more information about the
+ /// returned data.
+ ///
+ /// # Notes
+ ///
+ /// Following an unsuccessful parsing the state of the returned information
+ /// is undefined, it might contain only partial information about the
+ /// current shader, the previous shader or both.
+ pub const fn metadata(&self) -> &ShaderMetadata {
+ &self.meta
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/offset.rs b/third_party/rust/naga/src/front/glsl/offset.rs
new file mode 100644
index 0000000000..d7950489cb
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/offset.rs
@@ -0,0 +1,173 @@
+/*!
+Module responsible for calculating the offset and span for types.
+
+There exists two types of layouts std140 and std430 (there's technically
+two more layouts, shared and packed. Shared is not supported by spirv. Packed is
+implementation dependent and for now it's just implemented as an alias to
+std140).
+
+The OpenGl spec (the layout rules are defined by the OpenGl spec in section
+7.6.2.2 as opposed to the GLSL spec) uses the term basic machine units which are
+equivalent to bytes.
+*/
+
+use super::{
+ ast::StructLayout,
+ error::{Error, ErrorKind},
+ Span,
+};
+use crate::{proc::Alignment, Arena, Constant, Handle, Type, TypeInner, UniqueArena};
+
+/// Struct with information needed for defining a struct member.
+///
+/// Returned by [`calculate_offset`](calculate_offset)
+#[derive(Debug)]
+pub struct TypeAlignSpan {
+ /// The handle to the type, this might be the same handle passed to
+ /// [`calculate_offset`](calculate_offset) or a new such a new array type
+ /// with a different stride set.
+ pub ty: Handle<Type>,
+ /// The alignment required by the type.
+ pub align: Alignment,
+ /// The size of the type.
+ pub span: u32,
+}
+
+/// Returns the type, alignment and span of a struct member according to a [`StructLayout`](StructLayout).
+///
+/// The functions returns a [`TypeAlignSpan`](TypeAlignSpan) which has a `ty` member
+/// this should be used as the struct member type because for example arrays may have to
+/// change the stride and as such need to have a different type.
+pub fn calculate_offset(
+ mut ty: Handle<Type>,
+ meta: Span,
+ layout: StructLayout,
+ types: &mut UniqueArena<Type>,
+ constants: &Arena<Constant>,
+ errors: &mut Vec<Error>,
+) -> TypeAlignSpan {
+ // When using the std430 storage layout, shader storage blocks will be laid out in buffer storage
+ // identically to uniform and shader storage blocks using the std140 layout, except
+ // that the base alignment and stride of arrays of scalars and vectors in rule 4 and of
+ // structures in rule 9 are not rounded up a multiple of the base alignment of a vec4.
+
+ let (align, span) = match types[ty].inner {
+ // 1. If the member is a scalar consuming N basic machine units,
+ // the base alignment is N.
+ TypeInner::Scalar { width, .. } => (Alignment::from_width(width), width as u32),
+ // 2. If the member is a two- or four-component vector with components
+ // consuming N basic machine units, the base alignment is 2N or 4N, respectively.
+ // 3. If the member is a three-component vector with components consuming N
+ // basic machine units, the base alignment is 4N.
+ TypeInner::Vector { size, width, .. } => (
+ Alignment::from(size) * Alignment::from_width(width),
+ size as u32 * width as u32,
+ ),
+ // 4. If the member is an array of scalars or vectors, the base alignment and array
+ // stride are set to match the base alignment of a single array element, according
+ // to rules (1), (2), and (3), and rounded up to the base alignment of a vec4.
+ // TODO: Matrices array
+ TypeInner::Array { base, size, .. } => {
+ let info = calculate_offset(base, meta, layout, types, constants, errors);
+
+ let name = types[ty].name.clone();
+
+ // See comment at the beginning of the function
+ let (align, stride) = if StructLayout::Std430 == layout {
+ (info.align, info.align.round_up(info.span))
+ } else {
+ let align = info.align.max(Alignment::MIN_UNIFORM);
+ (align, align.round_up(info.span))
+ };
+
+ let span = match size {
+ crate::ArraySize::Constant(s) => {
+ constants[s].to_array_length().unwrap_or(1) * stride
+ }
+ crate::ArraySize::Dynamic => stride,
+ };
+
+ let ty_span = types.get_span(ty);
+ ty = types.insert(
+ Type {
+ name,
+ inner: TypeInner::Array {
+ base: info.ty,
+ size,
+ stride,
+ },
+ },
+ ty_span,
+ );
+
+ (align, span)
+ }
+ // 5. If the member is a column-major matrix with C columns and R rows, the
+ // matrix is stored identically to an array of C column vectors with R
+ // components each, according to rule (4)
+ // TODO: Row major matrices
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let mut align = Alignment::from(rows) * Alignment::from_width(width);
+
+ // See comment at the beginning of the function
+ if StructLayout::Std430 != layout {
+ align = align.max(Alignment::MIN_UNIFORM);
+ }
+
+ // See comment on the error kind
+ if StructLayout::Std140 == layout && rows == crate::VectorSize::Bi {
+ errors.push(Error {
+ kind: ErrorKind::UnsupportedMatrixTypeInStd140,
+ meta,
+ });
+ }
+
+ (align, align * columns as u32)
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let mut span = 0;
+ let mut align = Alignment::ONE;
+ let mut members = members.clone();
+ let name = types[ty].name.clone();
+
+ for member in members.iter_mut() {
+ let info = calculate_offset(member.ty, meta, layout, types, constants, errors);
+
+ let member_alignment = info.align;
+ span = member_alignment.round_up(span);
+ align = member_alignment.max(align);
+
+ member.ty = info.ty;
+ member.offset = span;
+
+ span += info.span;
+ }
+
+ span = align.round_up(span);
+
+ let ty_span = types.get_span(ty);
+ ty = types.insert(
+ Type {
+ name,
+ inner: TypeInner::Struct { members, span },
+ },
+ ty_span,
+ );
+
+ (align, span)
+ }
+ _ => {
+ errors.push(Error {
+ kind: ErrorKind::SemanticError("Invalid struct member type".into()),
+ meta,
+ });
+ (Alignment::ONE, 0)
+ }
+ };
+
+ TypeAlignSpan { ty, align, span }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs
new file mode 100644
index 0000000000..ed139e799d
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser.rs
@@ -0,0 +1,448 @@
+use super::{
+ ast::{FunctionKind, Profile, TypeQualifiers},
+ context::{Context, ExprPos},
+ error::ExpectedToken,
+ error::{Error, ErrorKind},
+ lex::{Lexer, LexerResultKind},
+ token::{Directive, DirectiveKind},
+ token::{Token, TokenValue},
+ variables::{GlobalOrConstant, VarDeclaration},
+ Frontend, Result,
+};
+use crate::{arena::Handle, Block, Constant, ConstantInner, Expression, ScalarValue, Span, Type};
+use core::convert::TryFrom;
+use pp_rs::token::{PreprocessorError, Token as PPToken, TokenValue as PPTokenValue};
+use std::iter::Peekable;
+
+mod declarations;
+mod expressions;
+mod functions;
+mod types;
+
+pub struct ParsingContext<'source> {
+ lexer: Peekable<Lexer<'source>>,
+ /// Used to store tokens already consumed by the parser but that need to be backtracked
+ backtracked_token: Option<Token>,
+ last_meta: Span,
+}
+
+impl<'source> ParsingContext<'source> {
+ pub fn new(lexer: Lexer<'source>) -> Self {
+ ParsingContext {
+ lexer: lexer.peekable(),
+ backtracked_token: None,
+ last_meta: Span::default(),
+ }
+ }
+
+ /// Helper method for backtracking from a consumed token
+ ///
+ /// This method should always be used instead of assigning to `backtracked_token` since
+ /// it validates that backtracking hasn't occurred more than one time in a row
+ ///
+ /// # Panics
+ /// - If the parser already backtracked without bumping in between
+ pub fn backtrack(&mut self, token: Token) -> Result<()> {
+ // This should never happen
+ if let Some(ref prev_token) = self.backtracked_token {
+ return Err(Error {
+ kind: ErrorKind::InternalError("The parser tried to backtrack twice in a row"),
+ meta: prev_token.meta,
+ });
+ }
+
+ self.backtracked_token = Some(token);
+
+ Ok(())
+ }
+
+ pub fn expect_ident(&mut self, frontend: &mut Frontend) -> Result<(String, Span)> {
+ let token = self.bump(frontend)?;
+
+ match token.value {
+ TokenValue::Identifier(name) => Ok((name, token.meta)),
+ _ => Err(Error {
+ kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]),
+ meta: token.meta,
+ }),
+ }
+ }
+
+ pub fn expect(&mut self, frontend: &mut Frontend, value: TokenValue) -> Result<Token> {
+ let token = self.bump(frontend)?;
+
+ if token.value != value {
+ Err(Error {
+ kind: ErrorKind::InvalidToken(token.value, vec![value.into()]),
+ meta: token.meta,
+ })
+ } else {
+ Ok(token)
+ }
+ }
+
+ pub fn next(&mut self, frontend: &mut Frontend) -> Option<Token> {
+ loop {
+ if let Some(token) = self.backtracked_token.take() {
+ self.last_meta = token.meta;
+ break Some(token);
+ }
+
+ let res = self.lexer.next()?;
+
+ match res.kind {
+ LexerResultKind::Token(token) => {
+ self.last_meta = token.meta;
+ break Some(token);
+ }
+ LexerResultKind::Directive(directive) => {
+ frontend.handle_directive(directive, res.meta)
+ }
+ LexerResultKind::Error(error) => frontend.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(error),
+ meta: res.meta,
+ }),
+ }
+ }
+ }
+
+ pub fn bump(&mut self, frontend: &mut Frontend) -> Result<Token> {
+ self.next(frontend).ok_or(Error {
+ kind: ErrorKind::EndOfFile,
+ meta: self.last_meta,
+ })
+ }
+
+ /// Returns None on the end of the file rather than an error like other methods
+ pub fn bump_if(&mut self, frontend: &mut Frontend, value: TokenValue) -> Option<Token> {
+ if self.peek(frontend).filter(|t| t.value == value).is_some() {
+ self.bump(frontend).ok()
+ } else {
+ None
+ }
+ }
+
+ pub fn peek(&mut self, frontend: &mut Frontend) -> Option<&Token> {
+ loop {
+ if let Some(ref token) = self.backtracked_token {
+ break Some(token);
+ }
+
+ match self.lexer.peek()?.kind {
+ LexerResultKind::Token(_) => {
+ let res = self.lexer.peek()?;
+
+ match res.kind {
+ LexerResultKind::Token(ref token) => break Some(token),
+ _ => unreachable!(),
+ }
+ }
+ LexerResultKind::Error(_) | LexerResultKind::Directive(_) => {
+ let res = self.lexer.next()?;
+
+ match res.kind {
+ LexerResultKind::Directive(directive) => {
+ frontend.handle_directive(directive, res.meta)
+ }
+ LexerResultKind::Error(error) => frontend.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(error),
+ meta: res.meta,
+ }),
+ LexerResultKind::Token(_) => unreachable!(),
+ }
+ }
+ }
+ }
+ }
+
+ pub fn expect_peek(&mut self, frontend: &mut Frontend) -> Result<&Token> {
+ let meta = self.last_meta;
+ self.peek(frontend).ok_or(Error {
+ kind: ErrorKind::EndOfFile,
+ meta,
+ })
+ }
+
+ pub fn parse(&mut self, frontend: &mut Frontend) -> Result<()> {
+ // Body and expression arena for global initialization
+ let mut body = Block::new();
+ let mut ctx = Context::new(frontend, &mut body);
+
+ while self.peek(frontend).is_some() {
+ self.parse_external_declaration(frontend, &mut ctx, &mut body)?;
+ }
+
+ // Add an `EntryPoint` to `parser.module` for `main`, if a
+ // suitable overload exists. Error out if we can't find one.
+ if let Some(declaration) = frontend.lookup_function.get("main") {
+ for decl in declaration.overloads.iter() {
+ if let FunctionKind::Call(handle) = decl.kind {
+ if decl.defined && decl.parameters.is_empty() {
+ frontend.add_entry_point(handle, body, ctx.expressions);
+ return Ok(());
+ }
+ }
+ }
+ }
+
+ Err(Error {
+ kind: ErrorKind::SemanticError("Missing entry point".into()),
+ meta: Span::default(),
+ })
+ }
+
+ fn parse_uint_constant(&mut self, frontend: &mut Frontend) -> Result<(u32, Span)> {
+ let (value, meta) = self.parse_constant_expression(frontend)?;
+
+ let int = match frontend.module.constants[value].inner {
+ ConstantInner::Scalar {
+ value: ScalarValue::Uint(int),
+ ..
+ } => u32::try_from(int).map_err(|_| Error {
+ kind: ErrorKind::SemanticError("int constant overflows".into()),
+ meta,
+ })?,
+ ConstantInner::Scalar {
+ value: ScalarValue::Sint(int),
+ ..
+ } => u32::try_from(int).map_err(|_| Error {
+ kind: ErrorKind::SemanticError("int constant overflows".into()),
+ meta,
+ })?,
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError("Expected a uint constant".into()),
+ meta,
+ })
+ }
+ };
+
+ Ok((int, meta))
+ }
+
+ fn parse_constant_expression(
+ &mut self,
+ frontend: &mut Frontend,
+ ) -> Result<(Handle<Constant>, Span)> {
+ let mut block = Block::new();
+
+ let mut ctx = Context::new(frontend, &mut block);
+
+ let mut stmt_ctx = ctx.stmt_ctx();
+ let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, &mut block, None)?;
+ let (root, meta) = ctx.lower_expect(stmt_ctx, frontend, expr, ExprPos::Rhs, &mut block)?;
+
+ Ok((frontend.solve_constant(&ctx, root, meta)?, meta))
+ }
+}
+
+impl Frontend {
+ fn handle_directive(&mut self, directive: Directive, meta: Span) {
+ let mut tokens = directive.tokens.into_iter();
+
+ match directive.kind {
+ DirectiveKind::Version { is_first_directive } => {
+ if !is_first_directive {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "#version must occur first in shader".into(),
+ ),
+ meta,
+ })
+ }
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Integer(int),
+ location,
+ }) => match int.value {
+ 440 | 450 | 460 => self.meta.version = int.value as u16,
+ _ => self.errors.push(Error {
+ kind: ErrorKind::InvalidVersion(int.value),
+ meta: location.into(),
+ }),
+ },
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine),
+ meta,
+ }),
+ };
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Ident(name),
+ location,
+ }) => match name.as_str() {
+ "core" => self.meta.profile = Profile::Core,
+ _ => self.errors.push(Error {
+ kind: ErrorKind::InvalidProfile(name),
+ meta: location.into(),
+ }),
+ },
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => {}
+ };
+
+ if let Some(PPToken { value, location }) = tokens.next() {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ })
+ }
+ }
+ DirectiveKind::Extension => {
+ // TODO: Proper extension handling
+ // - Checking for extension support in the compiler
+ // - Handle behaviors such as warn
+ // - Handle the all extension
+ let name = match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Ident(name),
+ ..
+ }) => Some(name),
+ Some(PPToken { value, location }) => {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ });
+
+ None
+ }
+ None => {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(
+ PreprocessorError::UnexpectedNewLine,
+ ),
+ meta,
+ });
+
+ None
+ }
+ };
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Punct(pp_rs::token::Punct::Colon),
+ ..
+ }) => {}
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine),
+ meta,
+ }),
+ };
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Ident(behavior),
+ location,
+ }) => match behavior.as_str() {
+ "require" | "enable" | "warn" | "disable" => {
+ if let Some(name) = name {
+ self.meta.extensions.insert(name);
+ }
+ }
+ _ => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ PPTokenValue::Ident(behavior),
+ )),
+ meta: location.into(),
+ }),
+ },
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine),
+ meta,
+ }),
+ }
+
+ if let Some(PPToken { value, location }) = tokens.next() {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ })
+ }
+ }
+ DirectiveKind::Pragma => {
+ // TODO: handle some common pragmas?
+ }
+ }
+ }
+}
+
+pub struct DeclarationContext<'ctx, 'qualifiers> {
+ qualifiers: TypeQualifiers<'qualifiers>,
+ /// Indicates a global declaration
+ external: bool,
+
+ ctx: &'ctx mut Context,
+ body: &'ctx mut Block,
+}
+
+impl<'ctx, 'qualifiers> DeclarationContext<'ctx, 'qualifiers> {
+ fn add_var(
+ &mut self,
+ frontend: &mut Frontend,
+ ty: Handle<Type>,
+ name: String,
+ init: Option<Handle<Constant>>,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let decl = VarDeclaration {
+ qualifiers: &mut self.qualifiers,
+ ty,
+ name: Some(name),
+ init,
+ meta,
+ };
+
+ match self.external {
+ true => {
+ let global = frontend.add_global_var(self.ctx, self.body, decl)?;
+ let expr = match global {
+ GlobalOrConstant::Global(handle) => Expression::GlobalVariable(handle),
+ GlobalOrConstant::Constant(handle) => Expression::Constant(handle),
+ };
+ Ok(self.ctx.add_expression(expr, meta, self.body))
+ }
+ false => frontend.add_local_var(self.ctx, self.body, decl),
+ }
+ }
+
+ /// Emits all the expressions captured by the emitter and starts the emitter again
+ ///
+ /// Alias to [`emit_restart`] with the declaration body
+ ///
+ /// [`emit_restart`]: Context::emit_restart
+ #[inline]
+ fn flush_expressions(&mut self) {
+ self.ctx.emit_restart(self.body);
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/declarations.rs b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
new file mode 100644
index 0000000000..9085469dda
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
@@ -0,0 +1,671 @@
+use crate::{
+ front::glsl::{
+ ast::{
+ GlobalLookup, GlobalLookupKind, Precision, QualifierKey, QualifierValue,
+ StorageQualifier, StructLayout, TypeQualifiers,
+ },
+ context::{Context, ExprPos},
+ error::ExpectedToken,
+ offset,
+ token::{Token, TokenValue},
+ types::scalar_components,
+ variables::{GlobalOrConstant, VarDeclaration},
+ Error, ErrorKind, Frontend, Span,
+ },
+ proc::Alignment,
+ AddressSpace, Block, Expression, FunctionResult, Handle, ScalarKind, Statement, StructMember,
+ Type, TypeInner,
+};
+
+use super::{DeclarationContext, ParsingContext, Result};
+
+/// Helper method used to retrieve the child type of `ty` at
+/// index `i`.
+///
+/// # Note
+///
+/// Does not check if the index is valid and returns the same type
+/// when indexing out-of-bounds a struct or indexing a non indexable
+/// type.
+fn element_or_member_type(
+ ty: Handle<Type>,
+ i: usize,
+ types: &mut crate::UniqueArena<Type>,
+) -> Handle<Type> {
+ match types[ty].inner {
+ // The child type of a vector is a scalar of the same kind and width
+ TypeInner::Vector { kind, width, .. } => types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar { kind, width },
+ },
+ Default::default(),
+ ),
+ // The child type of a matrix is a vector of floats with the same
+ // width and the size of the matrix rows.
+ TypeInner::Matrix { rows, width, .. } => types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ kind: ScalarKind::Float,
+ width,
+ },
+ },
+ Default::default(),
+ ),
+ // The child type of an array is the base type of the array
+ TypeInner::Array { base, .. } => base,
+ // The child type of a struct at index `i` is the type of it's
+ // member at that same index.
+ //
+ // In case the index is out of bounds the same type is returned
+ TypeInner::Struct { ref members, .. } => {
+ members.get(i).map(|member| member.ty).unwrap_or(ty)
+ }
+ // The type isn't indexable, the same type is returned
+ _ => ty,
+ }
+}
+
+impl<'source> ParsingContext<'source> {
+ pub fn parse_external_declaration(
+ &mut self,
+ frontend: &mut Frontend,
+ global_ctx: &mut Context,
+ global_body: &mut Block,
+ ) -> Result<()> {
+ if self
+ .parse_declaration(frontend, global_ctx, global_body, true)?
+ .is_none()
+ {
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Semicolon if frontend.meta.version == 460 => Ok(()),
+ _ => {
+ let expected = match frontend.meta.version {
+ 460 => vec![TokenValue::Semicolon.into(), ExpectedToken::Eof],
+ _ => vec![ExpectedToken::Eof],
+ };
+ Err(Error {
+ kind: ErrorKind::InvalidToken(token.value, expected),
+ meta: token.meta,
+ })
+ }
+ }
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn parse_initializer(
+ &mut self,
+ frontend: &mut Frontend,
+ ty: Handle<Type>,
+ ctx: &mut Context,
+ body: &mut Block,
+ ) -> Result<(Handle<Expression>, Span)> {
+ // initializer:
+ // assignment_expression
+ // LEFT_BRACE initializer_list RIGHT_BRACE
+ // LEFT_BRACE initializer_list COMMA RIGHT_BRACE
+ //
+ // initializer_list:
+ // initializer
+ // initializer_list COMMA initializer
+ if let Some(Token { mut meta, .. }) = self.bump_if(frontend, TokenValue::LeftBrace) {
+ // initializer_list
+ let mut components = Vec::new();
+ loop {
+ // The type expected to be parsed inside the initializer list
+ let new_ty =
+ element_or_member_type(ty, components.len(), &mut frontend.module.types);
+
+ components.push(self.parse_initializer(frontend, new_ty, ctx, body)?.0);
+
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Comma => {
+ if let Some(Token { meta: end_meta, .. }) =
+ self.bump_if(frontend, TokenValue::RightBrace)
+ {
+ meta.subsume(end_meta);
+ break;
+ }
+ }
+ TokenValue::RightBrace => {
+ meta.subsume(token.meta);
+ break;
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Comma.into(), TokenValue::RightBrace.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ }
+ }
+
+ Ok((
+ ctx.add_expression(Expression::Compose { ty, components }, meta, body),
+ meta,
+ ))
+ } else {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_assignment(frontend, ctx, &mut stmt, body)?;
+ let (mut init, init_meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+
+ let scalar_components = scalar_components(&frontend.module.types[ty].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(frontend, &mut init, init_meta, kind, width)?;
+ }
+
+ Ok((init, init_meta))
+ }
+ }
+
+ // Note: caller preparsed the type and qualifiers
+ // Note: caller skips this if the fallthrough token is not expected to be consumed here so this
+ // produced Error::InvalidToken if it isn't consumed
+ pub fn parse_init_declarator_list(
+ &mut self,
+ frontend: &mut Frontend,
+ mut ty: Handle<Type>,
+ ctx: &mut DeclarationContext,
+ ) -> Result<()> {
+ // init_declarator_list:
+ // single_declaration
+ // init_declarator_list COMMA IDENTIFIER
+ // init_declarator_list COMMA IDENTIFIER array_specifier
+ // init_declarator_list COMMA IDENTIFIER array_specifier EQUAL initializer
+ // init_declarator_list COMMA IDENTIFIER EQUAL initializer
+ //
+ // single_declaration:
+ // fully_specified_type
+ // fully_specified_type IDENTIFIER
+ // fully_specified_type IDENTIFIER array_specifier
+ // fully_specified_type IDENTIFIER array_specifier EQUAL initializer
+ // fully_specified_type IDENTIFIER EQUAL initializer
+
+ // Consume any leading comma, e.g. this is valid: `float, a=1;`
+ if self
+ .peek(frontend)
+ .map_or(false, |t| t.value == TokenValue::Comma)
+ {
+ self.next(frontend);
+ }
+
+ loop {
+ let token = self.bump(frontend)?;
+ let name = match token.value {
+ TokenValue::Semicolon => break,
+ TokenValue::Identifier(name) => name,
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ };
+ let mut meta = token.meta;
+
+ // array_specifier
+ // array_specifier EQUAL initializer
+ // EQUAL initializer
+
+ // parse an array specifier if it exists
+ // NOTE: unlike other parse methods this one doesn't expect an array specifier and
+ // returns Ok(None) rather than an error if there is not one
+ self.parse_array_specifier(frontend, &mut meta, &mut ty)?;
+
+ let init = self
+ .bump_if(frontend, TokenValue::Assign)
+ .map::<Result<_>, _>(|_| {
+ let (mut expr, init_meta) =
+ self.parse_initializer(frontend, ty, ctx.ctx, ctx.body)?;
+
+ let scalar_components = scalar_components(&frontend.module.types[ty].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.ctx
+ .implicit_conversion(frontend, &mut expr, init_meta, kind, width)?;
+ }
+
+ meta.subsume(init_meta);
+
+ Ok((expr, init_meta))
+ })
+ .transpose()?;
+
+ let is_const = ctx.qualifiers.storage.0 == StorageQualifier::Const;
+ let maybe_constant = if ctx.external {
+ if let Some((root, meta)) = init {
+ match frontend.solve_constant(ctx.ctx, root, meta) {
+ Ok(res) => Some(res),
+ // If the declaration is external (global scope) and is constant qualified
+ // then the initializer must be a constant expression
+ Err(err) if is_const => return Err(err),
+ _ => None,
+ }
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ let pointer = ctx.add_var(frontend, ty, name, maybe_constant, meta)?;
+
+ if let Some((value, _)) = init.filter(|_| maybe_constant.is_none()) {
+ ctx.flush_expressions();
+ ctx.body.push(Statement::Store { pointer, value }, meta);
+ }
+
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Semicolon => break,
+ TokenValue::Comma => {}
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Comma.into(), TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// `external` whether or not we are in a global or local context
+ pub fn parse_declaration(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ external: bool,
+ ) -> Result<Option<Span>> {
+ //declaration:
+ // function_prototype SEMICOLON
+ //
+ // init_declarator_list SEMICOLON
+ // PRECISION precision_qualifier type_specifier SEMICOLON
+ //
+ // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE SEMICOLON
+ // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER SEMICOLON
+ // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER array_specifier SEMICOLON
+ // type_qualifier SEMICOLON type_qualifier IDENTIFIER SEMICOLON
+ // type_qualifier IDENTIFIER identifier_list SEMICOLON
+
+ if self.peek_type_qualifier(frontend) || self.peek_type_name(frontend) {
+ let mut qualifiers = self.parse_type_qualifiers(frontend)?;
+
+ if self.peek_type_name(frontend) {
+ // This branch handles variables and function prototypes and if
+ // external is true also function definitions
+ let (ty, mut meta) = self.parse_type(frontend)?;
+
+ let token = self.bump(frontend)?;
+ let token_fallthrough = match token.value {
+ TokenValue::Identifier(name) => match self.expect_peek(frontend)?.value {
+ TokenValue::LeftParen => {
+ // This branch handles function definition and prototypes
+ self.bump(frontend)?;
+
+ let result = ty.map(|ty| FunctionResult { ty, binding: None });
+ let mut body = Block::new();
+
+ let mut context = Context::new(frontend, &mut body);
+
+ self.parse_function_args(frontend, &mut context, &mut body)?;
+
+ let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta;
+ meta.subsume(end_meta);
+
+ let token = self.bump(frontend)?;
+ return match token.value {
+ TokenValue::Semicolon => {
+ // This branch handles function prototypes
+ frontend.add_prototype(context, name, result, meta);
+
+ Ok(Some(meta))
+ }
+ TokenValue::LeftBrace if external => {
+ // This branch handles function definitions
+ // as you can see by the guard this branch
+ // only happens if external is also true
+
+ // parse the body
+ self.parse_compound_statement(
+ token.meta,
+ frontend,
+ &mut context,
+ &mut body,
+ &mut None,
+ )?;
+
+ frontend.add_function(context, name, result, body, meta);
+
+ Ok(Some(meta))
+ }
+ _ if external => Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::LeftBrace.into(),
+ TokenValue::Semicolon.into(),
+ ],
+ ),
+ meta: token.meta,
+ }),
+ _ => Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ }),
+ };
+ }
+ // Pass the token to the init_declarator_list parser
+ _ => Token {
+ value: TokenValue::Identifier(name),
+ meta: token.meta,
+ },
+ },
+ // Pass the token to the init_declarator_list parser
+ _ => token,
+ };
+
+ // If program execution has reached here then this will be a
+ // init_declarator_list
+ // token_fallthrough will have a token that was already bumped
+ if let Some(ty) = ty {
+ let mut ctx = DeclarationContext {
+ qualifiers,
+ external,
+ ctx,
+ body,
+ };
+
+ self.backtrack(token_fallthrough)?;
+ self.parse_init_declarator_list(frontend, ty, &mut ctx)?;
+ } else {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError("Declaration cannot have void type".into()),
+ meta,
+ })
+ }
+
+ Ok(Some(meta))
+ } else {
+ // This branch handles struct definitions and modifiers like
+ // ```glsl
+ // layout(early_fragment_tests);
+ // ```
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Identifier(ty_name) => {
+ if self.bump_if(frontend, TokenValue::LeftBrace).is_some() {
+ self.parse_block_declaration(
+ frontend,
+ ctx,
+ body,
+ &mut qualifiers,
+ ty_name,
+ token.meta,
+ )
+ .map(Some)
+ } else {
+ if qualifiers.invariant.take().is_some() {
+ frontend.make_variable_invariant(ctx, body, &ty_name, token.meta);
+
+ qualifiers.unused_errors(&mut frontend.errors);
+ self.expect(frontend, TokenValue::Semicolon)?;
+ return Ok(Some(qualifiers.span));
+ }
+
+ //TODO: declaration
+ // type_qualifier IDENTIFIER SEMICOLON
+ // type_qualifier IDENTIFIER identifier_list SEMICOLON
+ Err(Error {
+ kind: ErrorKind::NotImplemented("variable qualifier"),
+ meta: token.meta,
+ })
+ }
+ }
+ TokenValue::Semicolon => {
+ if let Some(value) =
+ qualifiers.uint_layout_qualifier("local_size_x", &mut frontend.errors)
+ {
+ frontend.meta.workgroup_size[0] = value;
+ }
+ if let Some(value) =
+ qualifiers.uint_layout_qualifier("local_size_y", &mut frontend.errors)
+ {
+ frontend.meta.workgroup_size[1] = value;
+ }
+ if let Some(value) =
+ qualifiers.uint_layout_qualifier("local_size_z", &mut frontend.errors)
+ {
+ frontend.meta.workgroup_size[2] = value;
+ }
+
+ frontend.meta.early_fragment_tests |= qualifiers
+ .none_layout_qualifier("early_fragment_tests", &mut frontend.errors);
+
+ qualifiers.unused_errors(&mut frontend.errors);
+
+ Ok(Some(qualifiers.span))
+ }
+ _ => Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ }),
+ }
+ }
+ } else {
+ match self.peek(frontend).map(|t| &t.value) {
+ Some(&TokenValue::Precision) => {
+ // PRECISION precision_qualifier type_specifier SEMICOLON
+ self.bump(frontend)?;
+
+ let token = self.bump(frontend)?;
+ let _ = match token.value {
+ TokenValue::PrecisionQualifier(p) => p,
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::PrecisionQualifier(Precision::High).into(),
+ TokenValue::PrecisionQualifier(Precision::Medium).into(),
+ TokenValue::PrecisionQualifier(Precision::Low).into(),
+ ],
+ ),
+ meta: token.meta,
+ })
+ }
+ };
+
+ let (ty, meta) = self.parse_type_non_void(frontend)?;
+
+ match frontend.module.types[ty].inner {
+ TypeInner::Scalar {
+ kind: ScalarKind::Float | ScalarKind::Sint,
+ ..
+ } => {}
+ _ => frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Precision statement can only work on floats and ints".into(),
+ ),
+ meta,
+ }),
+ }
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+
+ Ok(Some(meta))
+ }
+ _ => Ok(None),
+ }
+ }
+ }
+
+ pub fn parse_block_declaration(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ qualifiers: &mut TypeQualifiers,
+ ty_name: String,
+ mut meta: Span,
+ ) -> Result<Span> {
+ let layout = match qualifiers.layout_qualifiers.remove(&QualifierKey::Layout) {
+ Some((QualifierValue::Layout(l), _)) => l,
+ None => {
+ if let StorageQualifier::AddressSpace(AddressSpace::Storage { .. }) =
+ qualifiers.storage.0
+ {
+ StructLayout::Std430
+ } else {
+ StructLayout::Std140
+ }
+ }
+ _ => unreachable!(),
+ };
+
+ let mut members = Vec::new();
+ let span = self.parse_struct_declaration_list(frontend, &mut members, layout)?;
+ self.expect(frontend, TokenValue::RightBrace)?;
+
+ let mut ty = frontend.module.types.insert(
+ Type {
+ name: Some(ty_name),
+ inner: TypeInner::Struct {
+ members: members.clone(),
+ span,
+ },
+ },
+ Default::default(),
+ );
+
+ let token = self.bump(frontend)?;
+ let name = match token.value {
+ TokenValue::Semicolon => None,
+ TokenValue::Identifier(name) => {
+ self.parse_array_specifier(frontend, &mut meta, &mut ty)?;
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+
+ Some(name)
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ };
+
+ let global = frontend.add_global_var(
+ ctx,
+ body,
+ VarDeclaration {
+ qualifiers,
+ ty,
+ name,
+ init: None,
+ meta,
+ },
+ )?;
+
+ for (i, k, ty) in members.into_iter().enumerate().filter_map(|(i, m)| {
+ let ty = m.ty;
+ m.name.map(|s| (i as u32, s, ty))
+ }) {
+ let lookup = GlobalLookup {
+ kind: match global {
+ GlobalOrConstant::Global(handle) => GlobalLookupKind::BlockSelect(handle, i),
+ GlobalOrConstant::Constant(handle) => GlobalLookupKind::Constant(handle, ty),
+ },
+ entry_arg: None,
+ mutable: true,
+ };
+ ctx.add_global(frontend, &k, lookup, body);
+
+ frontend.global_variables.push((k, lookup));
+ }
+
+ Ok(meta)
+ }
+
+ // TODO: Accept layout arguments
+ pub fn parse_struct_declaration_list(
+ &mut self,
+ frontend: &mut Frontend,
+ members: &mut Vec<StructMember>,
+ layout: StructLayout,
+ ) -> Result<u32> {
+ let mut span = 0;
+ let mut align = Alignment::ONE;
+
+ loop {
+ // TODO: type_qualifier
+
+ let (mut ty, mut meta) = self.parse_type_non_void(frontend)?;
+ let (name, end_meta) = self.expect_ident(frontend)?;
+
+ meta.subsume(end_meta);
+
+ self.parse_array_specifier(frontend, &mut meta, &mut ty)?;
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+
+ let info = offset::calculate_offset(
+ ty,
+ meta,
+ layout,
+ &mut frontend.module.types,
+ &frontend.module.constants,
+ &mut frontend.errors,
+ );
+
+ let member_alignment = info.align;
+ span = member_alignment.round_up(span);
+ align = member_alignment.max(align);
+
+ members.push(StructMember {
+ name: Some(name),
+ ty: info.ty,
+ binding: None,
+ offset: span,
+ });
+
+ span += info.span;
+
+ if let TokenValue::RightBrace = self.expect_peek(frontend)?.value {
+ break;
+ }
+ }
+
+ span = align.round_up(span);
+
+ Ok(span)
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/expressions.rs b/third_party/rust/naga/src/front/glsl/parser/expressions.rs
new file mode 100644
index 0000000000..f09e58b6f6
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/expressions.rs
@@ -0,0 +1,546 @@
+use crate::{
+ front::glsl::{
+ ast::{FunctionCall, FunctionCallKind, HirExpr, HirExprKind},
+ context::{Context, StmtContext},
+ error::{ErrorKind, ExpectedToken},
+ parser::ParsingContext,
+ token::{Token, TokenValue},
+ Error, Frontend, Result, Span,
+ },
+ ArraySize, BinaryOperator, Block, Constant, ConstantInner, Handle, ScalarValue, Type,
+ TypeInner, UnaryOperator,
+};
+
+impl<'source> ParsingContext<'source> {
+ pub fn parse_primary(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ let mut token = self.bump(frontend)?;
+
+ let (width, value) = match token.value {
+ TokenValue::IntConstant(int) => (
+ (int.width / 8) as u8,
+ if int.signed {
+ ScalarValue::Sint(int.value as i64)
+ } else {
+ ScalarValue::Uint(int.value)
+ },
+ ),
+ TokenValue::FloatConstant(float) => (
+ (float.width / 8) as u8,
+ ScalarValue::Float(float.value as f64),
+ ),
+ TokenValue::BoolConstant(value) => (1, ScalarValue::Bool(value)),
+ TokenValue::LeftParen => {
+ let expr = self.parse_expression(frontend, ctx, stmt, body)?;
+ let meta = self.expect(frontend, TokenValue::RightParen)?.meta;
+
+ token.meta.subsume(meta);
+
+ return Ok(expr);
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::LeftParen.into(),
+ ExpectedToken::IntLiteral,
+ ExpectedToken::FloatLiteral,
+ ExpectedToken::BoolLiteral,
+ ],
+ ),
+ meta: token.meta,
+ });
+ }
+ };
+
+ let handle = frontend.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar { width, value },
+ },
+ token.meta,
+ );
+
+ Ok(stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Constant(handle),
+ meta: token.meta,
+ },
+ Default::default(),
+ ))
+ }
+
+ pub fn parse_function_call_args(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ meta: &mut Span,
+ ) -> Result<Vec<Handle<HirExpr>>> {
+ let mut args = Vec::new();
+ if let Some(token) = self.bump_if(frontend, TokenValue::RightParen) {
+ meta.subsume(token.meta);
+ } else {
+ loop {
+ args.push(self.parse_assignment(frontend, ctx, stmt, body)?);
+
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Comma => {}
+ TokenValue::RightParen => {
+ meta.subsume(token.meta);
+ break;
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Comma.into(), TokenValue::RightParen.into()],
+ ),
+ meta: token.meta,
+ });
+ }
+ }
+ }
+ }
+
+ Ok(args)
+ }
+
+ pub fn parse_postfix(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ let mut base = if self.peek_type_name(frontend) {
+ let (mut handle, mut meta) = self.parse_type_non_void(frontend)?;
+
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let args = self.parse_function_call_args(frontend, ctx, stmt, body, &mut meta)?;
+
+ if let TypeInner::Array {
+ size: ArraySize::Dynamic,
+ stride,
+ base,
+ } = frontend.module.types[handle].inner
+ {
+ let span = frontend.module.types.get_span(handle);
+
+ let constant = frontend.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Uint(args.len() as u64),
+ },
+ },
+ Span::default(),
+ );
+ handle = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Array {
+ stride,
+ base,
+ size: ArraySize::Constant(constant),
+ },
+ },
+ span,
+ )
+ }
+
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Call(FunctionCall {
+ kind: FunctionCallKind::TypeConstructor(handle),
+ args,
+ }),
+ meta,
+ },
+ Default::default(),
+ )
+ } else if let TokenValue::Identifier(_) = self.expect_peek(frontend)?.value {
+ let (name, mut meta) = self.expect_ident(frontend)?;
+
+ let expr = if self.bump_if(frontend, TokenValue::LeftParen).is_some() {
+ let args = self.parse_function_call_args(frontend, ctx, stmt, body, &mut meta)?;
+
+ let kind = match frontend.lookup_type.get(&name) {
+ Some(ty) => FunctionCallKind::TypeConstructor(*ty),
+ None => FunctionCallKind::Function(name),
+ };
+
+ HirExpr {
+ kind: HirExprKind::Call(FunctionCall { kind, args }),
+ meta,
+ }
+ } else {
+ let var = match frontend.lookup_variable(ctx, body, &name, meta) {
+ Some(var) => var,
+ None => {
+ return Err(Error {
+ kind: ErrorKind::UnknownVariable(name),
+ meta,
+ })
+ }
+ };
+
+ HirExpr {
+ kind: HirExprKind::Variable(var),
+ meta,
+ }
+ };
+
+ stmt.hir_exprs.append(expr, Default::default())
+ } else {
+ self.parse_primary(frontend, ctx, stmt, body)?
+ };
+
+ while let TokenValue::LeftBracket
+ | TokenValue::Dot
+ | TokenValue::Increment
+ | TokenValue::Decrement = self.expect_peek(frontend)?.value
+ {
+ let Token { value, mut meta } = self.bump(frontend)?;
+
+ match value {
+ TokenValue::LeftBracket => {
+ let index = self.parse_expression(frontend, ctx, stmt, body)?;
+ let end_meta = self.expect(frontend, TokenValue::RightBracket)?.meta;
+
+ meta.subsume(end_meta);
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Access { base, index },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ TokenValue::Dot => {
+ let (field, end_meta) = self.expect_ident(frontend)?;
+
+ if self.bump_if(frontend, TokenValue::LeftParen).is_some() {
+ let args =
+ self.parse_function_call_args(frontend, ctx, stmt, body, &mut meta)?;
+
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Method {
+ expr: base,
+ name: field,
+ args,
+ },
+ meta,
+ },
+ Default::default(),
+ );
+ continue;
+ }
+
+ meta.subsume(end_meta);
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Select { base, field },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ TokenValue::Increment | TokenValue::Decrement => {
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::PrePostfix {
+ op: match value {
+ TokenValue::Increment => crate::BinaryOperator::Add,
+ _ => crate::BinaryOperator::Subtract,
+ },
+ postfix: true,
+ expr: base,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ Ok(base)
+ }
+
+ pub fn parse_unary(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ Ok(match self.expect_peek(frontend)?.value {
+ TokenValue::Plus | TokenValue::Dash | TokenValue::Bang | TokenValue::Tilde => {
+ let Token { value, mut meta } = self.bump(frontend)?;
+
+ let expr = self.parse_unary(frontend, ctx, stmt, body)?;
+ let end_meta = stmt.hir_exprs[expr].meta;
+
+ let kind = match value {
+ TokenValue::Dash => HirExprKind::Unary {
+ op: UnaryOperator::Negate,
+ expr,
+ },
+ TokenValue::Bang | TokenValue::Tilde => HirExprKind::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ _ => return Ok(expr),
+ };
+
+ meta.subsume(end_meta);
+ stmt.hir_exprs
+ .append(HirExpr { kind, meta }, Default::default())
+ }
+ TokenValue::Increment | TokenValue::Decrement => {
+ let Token { value, meta } = self.bump(frontend)?;
+
+ let expr = self.parse_unary(frontend, ctx, stmt, body)?;
+
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::PrePostfix {
+ op: match value {
+ TokenValue::Increment => crate::BinaryOperator::Add,
+ _ => crate::BinaryOperator::Subtract,
+ },
+ postfix: false,
+ expr,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ _ => self.parse_postfix(frontend, ctx, stmt, body)?,
+ })
+ }
+
+ pub fn parse_binary(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ passthrough: Option<Handle<HirExpr>>,
+ min_bp: u8,
+ ) -> Result<Handle<HirExpr>> {
+ let mut left = passthrough
+ .ok_or(ErrorKind::EndOfFile /* Dummy error */)
+ .or_else(|_| self.parse_unary(frontend, ctx, stmt, body))?;
+ let mut meta = stmt.hir_exprs[left].meta;
+
+ while let Some((l_bp, r_bp)) = binding_power(&self.expect_peek(frontend)?.value) {
+ if l_bp < min_bp {
+ break;
+ }
+
+ let Token { value, .. } = self.bump(frontend)?;
+
+ let right = self.parse_binary(frontend, ctx, stmt, body, None, r_bp)?;
+ let end_meta = stmt.hir_exprs[right].meta;
+
+ meta.subsume(end_meta);
+ left = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Binary {
+ left,
+ op: match value {
+ TokenValue::LogicalOr => BinaryOperator::LogicalOr,
+ TokenValue::LogicalXor => BinaryOperator::NotEqual,
+ TokenValue::LogicalAnd => BinaryOperator::LogicalAnd,
+ TokenValue::VerticalBar => BinaryOperator::InclusiveOr,
+ TokenValue::Caret => BinaryOperator::ExclusiveOr,
+ TokenValue::Ampersand => BinaryOperator::And,
+ TokenValue::Equal => BinaryOperator::Equal,
+ TokenValue::NotEqual => BinaryOperator::NotEqual,
+ TokenValue::GreaterEqual => BinaryOperator::GreaterEqual,
+ TokenValue::LessEqual => BinaryOperator::LessEqual,
+ TokenValue::LeftAngle => BinaryOperator::Less,
+ TokenValue::RightAngle => BinaryOperator::Greater,
+ TokenValue::LeftShift => BinaryOperator::ShiftLeft,
+ TokenValue::RightShift => BinaryOperator::ShiftRight,
+ TokenValue::Plus => BinaryOperator::Add,
+ TokenValue::Dash => BinaryOperator::Subtract,
+ TokenValue::Star => BinaryOperator::Multiply,
+ TokenValue::Slash => BinaryOperator::Divide,
+ TokenValue::Percent => BinaryOperator::Modulo,
+ _ => unreachable!(),
+ },
+ right,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+
+ Ok(left)
+ }
+
+ pub fn parse_conditional(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ passthrough: Option<Handle<HirExpr>>,
+ ) -> Result<Handle<HirExpr>> {
+ let mut condition = self.parse_binary(frontend, ctx, stmt, body, passthrough, 0)?;
+ let mut meta = stmt.hir_exprs[condition].meta;
+
+ if self.bump_if(frontend, TokenValue::Question).is_some() {
+ let accept = self.parse_expression(frontend, ctx, stmt, body)?;
+ self.expect(frontend, TokenValue::Colon)?;
+ let reject = self.parse_assignment(frontend, ctx, stmt, body)?;
+ let end_meta = stmt.hir_exprs[reject].meta;
+
+ meta.subsume(end_meta);
+ condition = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Conditional {
+ condition,
+ accept,
+ reject,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+
+ Ok(condition)
+ }
+
+ pub fn parse_assignment(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ let tgt = self.parse_unary(frontend, ctx, stmt, body)?;
+ let mut meta = stmt.hir_exprs[tgt].meta;
+
+ Ok(match self.expect_peek(frontend)?.value {
+ TokenValue::Assign => {
+ self.bump(frontend)?;
+ let value = self.parse_assignment(frontend, ctx, stmt, body)?;
+ let end_meta = stmt.hir_exprs[value].meta;
+
+ meta.subsume(end_meta);
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Assign { tgt, value },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ TokenValue::OrAssign
+ | TokenValue::AndAssign
+ | TokenValue::AddAssign
+ | TokenValue::DivAssign
+ | TokenValue::ModAssign
+ | TokenValue::SubAssign
+ | TokenValue::MulAssign
+ | TokenValue::LeftShiftAssign
+ | TokenValue::RightShiftAssign
+ | TokenValue::XorAssign => {
+ let token = self.bump(frontend)?;
+ let right = self.parse_assignment(frontend, ctx, stmt, body)?;
+ let end_meta = stmt.hir_exprs[right].meta;
+
+ meta.subsume(end_meta);
+ let value = stmt.hir_exprs.append(
+ HirExpr {
+ meta,
+ kind: HirExprKind::Binary {
+ left: tgt,
+ op: match token.value {
+ TokenValue::OrAssign => BinaryOperator::InclusiveOr,
+ TokenValue::AndAssign => BinaryOperator::And,
+ TokenValue::AddAssign => BinaryOperator::Add,
+ TokenValue::DivAssign => BinaryOperator::Divide,
+ TokenValue::ModAssign => BinaryOperator::Modulo,
+ TokenValue::SubAssign => BinaryOperator::Subtract,
+ TokenValue::MulAssign => BinaryOperator::Multiply,
+ TokenValue::LeftShiftAssign => BinaryOperator::ShiftLeft,
+ TokenValue::RightShiftAssign => BinaryOperator::ShiftRight,
+ TokenValue::XorAssign => BinaryOperator::ExclusiveOr,
+ _ => unreachable!(),
+ },
+ right,
+ },
+ },
+ Default::default(),
+ );
+
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Assign { tgt, value },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ _ => self.parse_conditional(frontend, ctx, stmt, body, Some(tgt))?,
+ })
+ }
+
+ pub fn parse_expression(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ let mut expr = self.parse_assignment(frontend, ctx, stmt, body)?;
+
+ while let TokenValue::Comma = self.expect_peek(frontend)?.value {
+ self.bump(frontend)?;
+ expr = self.parse_assignment(frontend, ctx, stmt, body)?;
+ }
+
+ Ok(expr)
+ }
+}
+
+const fn binding_power(value: &TokenValue) -> Option<(u8, u8)> {
+ Some(match *value {
+ TokenValue::LogicalOr => (1, 2),
+ TokenValue::LogicalXor => (3, 4),
+ TokenValue::LogicalAnd => (5, 6),
+ TokenValue::VerticalBar => (7, 8),
+ TokenValue::Caret => (9, 10),
+ TokenValue::Ampersand => (11, 12),
+ TokenValue::Equal | TokenValue::NotEqual => (13, 14),
+ TokenValue::GreaterEqual
+ | TokenValue::LessEqual
+ | TokenValue::LeftAngle
+ | TokenValue::RightAngle => (15, 16),
+ TokenValue::LeftShift | TokenValue::RightShift => (17, 18),
+ TokenValue::Plus | TokenValue::Dash => (19, 20),
+ TokenValue::Star | TokenValue::Slash | TokenValue::Percent => (21, 22),
+ _ => return None,
+ })
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs
new file mode 100644
index 0000000000..781f4cffed
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs
@@ -0,0 +1,654 @@
+use crate::front::glsl::context::ExprPos;
+use crate::front::glsl::Span;
+use crate::{
+ front::glsl::{
+ ast::ParameterQualifier,
+ context::Context,
+ parser::ParsingContext,
+ token::{Token, TokenValue},
+ variables::VarDeclaration,
+ Error, ErrorKind, Frontend, Result,
+ },
+ Block, ConstantInner, Expression, ScalarValue, Statement, SwitchCase, UnaryOperator,
+};
+
+impl<'source> ParsingContext<'source> {
+ pub fn peek_parameter_qualifier(&mut self, frontend: &mut Frontend) -> bool {
+ self.peek(frontend).map_or(false, |t| match t.value {
+ TokenValue::In | TokenValue::Out | TokenValue::InOut | TokenValue::Const => true,
+ _ => false,
+ })
+ }
+
+ /// Returns the parsed `ParameterQualifier` or `ParameterQualifier::In`
+ pub fn parse_parameter_qualifier(&mut self, frontend: &mut Frontend) -> ParameterQualifier {
+ if self.peek_parameter_qualifier(frontend) {
+ match self.bump(frontend).unwrap().value {
+ TokenValue::In => ParameterQualifier::In,
+ TokenValue::Out => ParameterQualifier::Out,
+ TokenValue::InOut => ParameterQualifier::InOut,
+ TokenValue::Const => ParameterQualifier::Const,
+ _ => unreachable!(),
+ }
+ } else {
+ ParameterQualifier::In
+ }
+ }
+
+ pub fn parse_statement(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ terminator: &mut Option<usize>,
+ ) -> Result<Option<Span>> {
+ // Type qualifiers always identify a declaration statement
+ if self.peek_type_qualifier(frontend) {
+ return self.parse_declaration(frontend, ctx, body, false);
+ }
+
+ // Type names can identify either declaration statements or type constructors
+ // depending on wether the token following the type name is a `(` (LeftParen)
+ if self.peek_type_name(frontend) {
+ // Start by consuming the type name so that we can peek the token after it
+ let token = self.bump(frontend)?;
+ // Peek the next token and check if it's a `(` (LeftParen) if so the statement
+ // is a constructor, otherwise it's a declaration. We need to do the check
+ // beforehand and not in the if since we will backtrack before the if
+ let declaration = TokenValue::LeftParen != self.expect_peek(frontend)?.value;
+
+ self.backtrack(token)?;
+
+ if declaration {
+ return self.parse_declaration(frontend, ctx, body, false);
+ }
+ }
+
+ let new_break = || {
+ let mut block = Block::new();
+ block.push(Statement::Break, crate::Span::default());
+ block
+ };
+
+ let &Token {
+ ref value,
+ mut meta,
+ } = self.expect_peek(frontend)?;
+
+ let meta_rest = match *value {
+ TokenValue::Continue => {
+ let meta = self.bump(frontend)?.meta;
+ body.push(Statement::Continue, meta);
+ terminator.get_or_insert(body.len());
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ TokenValue::Break => {
+ let meta = self.bump(frontend)?.meta;
+ body.push(Statement::Break, meta);
+ terminator.get_or_insert(body.len());
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ TokenValue::Return => {
+ self.bump(frontend)?;
+ let (value, meta) = match self.expect_peek(frontend)?.value {
+ TokenValue::Semicolon => (None, self.bump(frontend)?.meta),
+ _ => {
+ // TODO: Implicit conversions
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ self.expect(frontend, TokenValue::Semicolon)?;
+ let (handle, meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ (Some(handle), meta)
+ }
+ };
+
+ ctx.emit_restart(body);
+
+ body.push(Statement::Return { value }, meta);
+ terminator.get_or_insert(body.len());
+
+ meta
+ }
+ TokenValue::Discard => {
+ let meta = self.bump(frontend)?.meta;
+ body.push(Statement::Kill, meta);
+ terminator.get_or_insert(body.len());
+
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ TokenValue::If => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let condition = {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ let (handle, more_meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ meta.subsume(more_meta);
+ handle
+ };
+ self.expect(frontend, TokenValue::RightParen)?;
+
+ ctx.emit_restart(body);
+
+ let mut accept = Block::new();
+ if let Some(more_meta) =
+ self.parse_statement(frontend, ctx, &mut accept, &mut None)?
+ {
+ meta.subsume(more_meta)
+ }
+
+ let mut reject = Block::new();
+ if self.bump_if(frontend, TokenValue::Else).is_some() {
+ if let Some(more_meta) =
+ self.parse_statement(frontend, ctx, &mut reject, &mut None)?
+ {
+ meta.subsume(more_meta);
+ }
+ }
+
+ body.push(
+ Statement::If {
+ condition,
+ accept,
+ reject,
+ },
+ meta,
+ );
+
+ meta
+ }
+ TokenValue::Switch => {
+ let mut meta = self.bump(frontend)?.meta;
+ let end_meta;
+
+ self.expect(frontend, TokenValue::LeftParen)?;
+
+ let (selector, uint) = {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ let (root, meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ let uint = frontend.resolve_type(ctx, root, meta)?.scalar_kind()
+ == Some(crate::ScalarKind::Uint);
+ (root, uint)
+ };
+
+ self.expect(frontend, TokenValue::RightParen)?;
+
+ ctx.emit_restart(body);
+
+ let mut cases = Vec::new();
+ // Track if any default case is present in the switch statement.
+ let mut default_present = false;
+
+ self.expect(frontend, TokenValue::LeftBrace)?;
+ loop {
+ let value = match self.expect_peek(frontend)?.value {
+ TokenValue::Case => {
+ self.bump(frontend)?;
+
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ let (root, meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ let constant = frontend.solve_constant(ctx, root, meta)?;
+
+ match frontend.module.constants[constant].inner {
+ ConstantInner::Scalar {
+ value: ScalarValue::Sint(int),
+ ..
+ } => match uint {
+ true => crate::SwitchValue::U32(int as u32),
+ false => crate::SwitchValue::I32(int as i32),
+ },
+ ConstantInner::Scalar {
+ value: ScalarValue::Uint(int),
+ ..
+ } => crate::SwitchValue::U32(int as u32),
+ _ => {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Case values can only be integers".into(),
+ ),
+ meta,
+ });
+
+ crate::SwitchValue::I32(0)
+ }
+ }
+ }
+ TokenValue::Default => {
+ self.bump(frontend)?;
+ default_present = true;
+ crate::SwitchValue::Default
+ }
+ TokenValue::RightBrace => {
+ end_meta = self.bump(frontend)?.meta;
+ break;
+ }
+ _ => {
+ let Token { value, meta } = self.bump(frontend)?;
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ value,
+ vec![
+ TokenValue::Case.into(),
+ TokenValue::Default.into(),
+ TokenValue::RightBrace.into(),
+ ],
+ ),
+ meta,
+ });
+ }
+ };
+
+ self.expect(frontend, TokenValue::Colon)?;
+
+ let mut body = Block::new();
+
+ let mut case_terminator = None;
+ loop {
+ match self.expect_peek(frontend)?.value {
+ TokenValue::Case | TokenValue::Default | TokenValue::RightBrace => {
+ break
+ }
+ _ => {
+ self.parse_statement(
+ frontend,
+ ctx,
+ &mut body,
+ &mut case_terminator,
+ )?;
+ }
+ }
+ }
+
+ let mut fall_through = true;
+
+ if let Some(mut idx) = case_terminator {
+ if let Statement::Break = body[idx - 1] {
+ fall_through = false;
+ idx -= 1;
+ }
+
+ body.cull(idx..)
+ }
+
+ cases.push(SwitchCase {
+ value,
+ body,
+ fall_through,
+ })
+ }
+
+ meta.subsume(end_meta);
+
+ // NOTE: do not unwrap here since a switch statement isn't required
+ // to have any cases.
+ if let Some(case) = cases.last_mut() {
+ // GLSL requires that the last case not be empty, so we check
+ // that here and produce an error otherwise (fall_through must
+ // also be checked because `break`s count as statements but
+ // they aren't added to the body)
+ if case.body.is_empty() && case.fall_through {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "last case/default label must be followed by statements".into(),
+ ),
+ meta,
+ })
+ }
+
+ // GLSL allows the last case to not have any `break` statement,
+ // this would mark it as fall through but naga's IR requires that
+ // the last case must not be fall through, so we mark need to mark
+ // the last case as not fall through always.
+ case.fall_through = false;
+ }
+
+ // Add an empty default case in case non was present, this is needed because
+ // naga's IR requires that all switch statements must have a default case but
+ // GLSL doesn't require that, so we might need to add an empty default case.
+ if !default_present {
+ cases.push(SwitchCase {
+ value: crate::SwitchValue::Default,
+ body: Block::new(),
+ fall_through: false,
+ })
+ }
+
+ body.push(Statement::Switch { selector, cases }, meta);
+
+ meta
+ }
+ TokenValue::While => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ let mut loop_body = Block::new();
+
+ let mut stmt = ctx.stmt_ctx();
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let root = self.parse_expression(frontend, ctx, &mut stmt, &mut loop_body)?;
+ meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta);
+
+ let (expr, expr_meta) =
+ ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs, &mut loop_body)?;
+ let condition = ctx.add_expression(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ expr_meta,
+ &mut loop_body,
+ );
+
+ ctx.emit_restart(&mut loop_body);
+
+ loop_body.push(
+ Statement::If {
+ condition,
+ accept: new_break(),
+ reject: Block::new(),
+ },
+ crate::Span::default(),
+ );
+
+ meta.subsume(expr_meta);
+
+ if let Some(body_meta) =
+ self.parse_statement(frontend, ctx, &mut loop_body, &mut None)?
+ {
+ meta.subsume(body_meta);
+ }
+
+ body.push(
+ Statement::Loop {
+ body: loop_body,
+ continuing: Block::new(),
+ break_if: None,
+ },
+ meta,
+ );
+
+ meta
+ }
+ TokenValue::Do => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ let mut loop_body = Block::new();
+
+ let mut terminator = None;
+ self.parse_statement(frontend, ctx, &mut loop_body, &mut terminator)?;
+
+ let mut stmt = ctx.stmt_ctx();
+
+ self.expect(frontend, TokenValue::While)?;
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let root = self.parse_expression(frontend, ctx, &mut stmt, &mut loop_body)?;
+ let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta;
+
+ meta.subsume(end_meta);
+
+ let (expr, expr_meta) =
+ ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs, &mut loop_body)?;
+ let condition = ctx.add_expression(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ expr_meta,
+ &mut loop_body,
+ );
+
+ ctx.emit_restart(&mut loop_body);
+
+ loop_body.push(
+ Statement::If {
+ condition,
+ accept: new_break(),
+ reject: Block::new(),
+ },
+ crate::Span::default(),
+ );
+
+ if let Some(idx) = terminator {
+ loop_body.cull(idx..)
+ }
+
+ body.push(
+ Statement::Loop {
+ body: loop_body,
+ continuing: Block::new(),
+ break_if: None,
+ },
+ meta,
+ );
+
+ meta
+ }
+ TokenValue::For => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ ctx.symbol_table.push_scope();
+ self.expect(frontend, TokenValue::LeftParen)?;
+
+ if self.bump_if(frontend, TokenValue::Semicolon).is_none() {
+ if self.peek_type_name(frontend) || self.peek_type_qualifier(frontend) {
+ self.parse_declaration(frontend, ctx, body, false)?;
+ } else {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ ctx.lower(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ self.expect(frontend, TokenValue::Semicolon)?;
+ }
+ }
+
+ let (mut block, mut continuing) = (Block::new(), Block::new());
+
+ if self.bump_if(frontend, TokenValue::Semicolon).is_none() {
+ let (expr, expr_meta) = if self.peek_type_name(frontend)
+ || self.peek_type_qualifier(frontend)
+ {
+ let mut qualifiers = self.parse_type_qualifiers(frontend)?;
+ let (ty, mut meta) = self.parse_type_non_void(frontend)?;
+ let name = self.expect_ident(frontend)?.0;
+
+ self.expect(frontend, TokenValue::Assign)?;
+
+ let (value, end_meta) =
+ self.parse_initializer(frontend, ty, ctx, &mut block)?;
+ meta.subsume(end_meta);
+
+ let decl = VarDeclaration {
+ qualifiers: &mut qualifiers,
+ ty,
+ name: Some(name),
+ init: None,
+ meta,
+ };
+
+ let pointer = frontend.add_local_var(ctx, &mut block, decl)?;
+
+ ctx.emit_restart(&mut block);
+
+ block.push(Statement::Store { pointer, value }, meta);
+
+ (value, end_meta)
+ } else {
+ let mut stmt = ctx.stmt_ctx();
+ let root = self.parse_expression(frontend, ctx, &mut stmt, &mut block)?;
+ ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs, &mut block)?
+ };
+
+ let condition = ctx.add_expression(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ expr_meta,
+ &mut block,
+ );
+
+ ctx.emit_restart(&mut block);
+
+ block.push(
+ Statement::If {
+ condition,
+ accept: new_break(),
+ reject: Block::new(),
+ },
+ crate::Span::default(),
+ );
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+ }
+
+ match self.expect_peek(frontend)?.value {
+ TokenValue::RightParen => {}
+ _ => {
+ let mut stmt = ctx.stmt_ctx();
+ let rest =
+ self.parse_expression(frontend, ctx, &mut stmt, &mut continuing)?;
+ ctx.lower(stmt, frontend, rest, ExprPos::Rhs, &mut continuing)?;
+ }
+ }
+
+ meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta);
+
+ if let Some(stmt_meta) =
+ self.parse_statement(frontend, ctx, &mut block, &mut None)?
+ {
+ meta.subsume(stmt_meta);
+ }
+
+ body.push(
+ Statement::Loop {
+ body: block,
+ continuing,
+ break_if: None,
+ },
+ meta,
+ );
+
+ ctx.symbol_table.pop_scope();
+
+ meta
+ }
+ TokenValue::LeftBrace => {
+ let meta = self.bump(frontend)?.meta;
+
+ let mut block = Block::new();
+
+ let mut block_terminator = None;
+ let meta = self.parse_compound_statement(
+ meta,
+ frontend,
+ ctx,
+ &mut block,
+ &mut block_terminator,
+ )?;
+
+ body.push(Statement::Block(block), meta);
+ if block_terminator.is_some() {
+ terminator.get_or_insert(body.len());
+ }
+
+ meta
+ }
+ TokenValue::Semicolon => self.bump(frontend)?.meta,
+ _ => {
+ // Attempt to force expression parsing for remainder of the
+ // tokens. Unknown or invalid tokens will be caught there and
+ // turned into an error.
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ ctx.lower(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ };
+
+ meta.subsume(meta_rest);
+ Ok(Some(meta))
+ }
+
+ pub fn parse_compound_statement(
+ &mut self,
+ mut meta: Span,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ terminator: &mut Option<usize>,
+ ) -> Result<Span> {
+ ctx.symbol_table.push_scope();
+
+ loop {
+ if let Some(Token {
+ meta: brace_meta, ..
+ }) = self.bump_if(frontend, TokenValue::RightBrace)
+ {
+ meta.subsume(brace_meta);
+ break;
+ }
+
+ let stmt = self.parse_statement(frontend, ctx, body, terminator)?;
+
+ if let Some(stmt_meta) = stmt {
+ meta.subsume(stmt_meta);
+ }
+ }
+
+ if let Some(idx) = *terminator {
+ body.cull(idx..)
+ }
+
+ ctx.symbol_table.pop_scope();
+
+ Ok(meta)
+ }
+
+ pub fn parse_function_args(
+ &mut self,
+ frontend: &mut Frontend,
+ context: &mut Context,
+ body: &mut Block,
+ ) -> Result<()> {
+ if self.bump_if(frontend, TokenValue::Void).is_some() {
+ return Ok(());
+ }
+
+ loop {
+ if self.peek_type_name(frontend) || self.peek_parameter_qualifier(frontend) {
+ let qualifier = self.parse_parameter_qualifier(frontend);
+ let mut ty = self.parse_type_non_void(frontend)?.0;
+
+ match self.expect_peek(frontend)?.value {
+ TokenValue::Comma => {
+ self.bump(frontend)?;
+ context.add_function_arg(frontend, body, None, ty, qualifier);
+ continue;
+ }
+ TokenValue::Identifier(_) => {
+ let mut name = self.expect_ident(frontend)?;
+ self.parse_array_specifier(frontend, &mut name.1, &mut ty)?;
+
+ context.add_function_arg(frontend, body, Some(name), ty, qualifier);
+
+ if self.bump_if(frontend, TokenValue::Comma).is_some() {
+ continue;
+ }
+
+ break;
+ }
+ _ => break,
+ }
+ }
+
+ break;
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/types.rs b/third_party/rust/naga/src/front/glsl/parser/types.rs
new file mode 100644
index 0000000000..08a70669a0
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/types.rs
@@ -0,0 +1,434 @@
+use crate::{
+ front::glsl::{
+ ast::{QualifierKey, QualifierValue, StorageQualifier, StructLayout, TypeQualifiers},
+ error::ExpectedToken,
+ parser::ParsingContext,
+ token::{Token, TokenValue},
+ Error, ErrorKind, Frontend, Result,
+ },
+ AddressSpace, ArraySize, Handle, Span, Type, TypeInner,
+};
+
+impl<'source> ParsingContext<'source> {
+ /// Parses an optional array_specifier returning wether or not it's present
+ /// and modifying the type handle if it exists
+ pub fn parse_array_specifier(
+ &mut self,
+ frontend: &mut Frontend,
+ span: &mut Span,
+ ty: &mut Handle<Type>,
+ ) -> Result<()> {
+ while self.parse_array_specifier_single(frontend, span, ty)? {}
+ Ok(())
+ }
+
+ /// Implementation of [`Self::parse_array_specifier`] for a single array_specifier
+ fn parse_array_specifier_single(
+ &mut self,
+ frontend: &mut Frontend,
+ span: &mut Span,
+ ty: &mut Handle<Type>,
+ ) -> Result<bool> {
+ if self.bump_if(frontend, TokenValue::LeftBracket).is_some() {
+ let size = if let Some(Token { meta, .. }) =
+ self.bump_if(frontend, TokenValue::RightBracket)
+ {
+ span.subsume(meta);
+ ArraySize::Dynamic
+ } else {
+ let (value, constant_span) = self.parse_uint_constant(frontend)?;
+ let constant = frontend.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(value as u64),
+ },
+ },
+ constant_span,
+ );
+ let end_span = self.expect(frontend, TokenValue::RightBracket)?.meta;
+ span.subsume(end_span);
+ ArraySize::Constant(constant)
+ };
+
+ frontend
+ .layouter
+ .update(&frontend.module.types, &frontend.module.constants)
+ .unwrap();
+ let stride = frontend.layouter[*ty].to_stride();
+ *ty = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Array {
+ base: *ty,
+ size,
+ stride,
+ },
+ },
+ *span,
+ );
+
+ Ok(true)
+ } else {
+ Ok(false)
+ }
+ }
+
+ pub fn parse_type(&mut self, frontend: &mut Frontend) -> Result<(Option<Handle<Type>>, Span)> {
+ let token = self.bump(frontend)?;
+ let mut handle = match token.value {
+ TokenValue::Void => return Ok((None, token.meta)),
+ TokenValue::TypeName(ty) => frontend.module.types.insert(ty, token.meta),
+ TokenValue::Struct => {
+ let mut meta = token.meta;
+ let ty_name = self.expect_ident(frontend)?.0;
+ self.expect(frontend, TokenValue::LeftBrace)?;
+ let mut members = Vec::new();
+ let span = self.parse_struct_declaration_list(
+ frontend,
+ &mut members,
+ StructLayout::Std140,
+ )?;
+ let end_meta = self.expect(frontend, TokenValue::RightBrace)?.meta;
+ meta.subsume(end_meta);
+ let ty = frontend.module.types.insert(
+ Type {
+ name: Some(ty_name.clone()),
+ inner: TypeInner::Struct { members, span },
+ },
+ meta,
+ );
+ frontend.lookup_type.insert(ty_name, ty);
+ ty
+ }
+ TokenValue::Identifier(ident) => match frontend.lookup_type.get(&ident) {
+ Some(ty) => *ty,
+ None => {
+ return Err(Error {
+ kind: ErrorKind::UnknownType(ident),
+ meta: token.meta,
+ })
+ }
+ },
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::Void.into(),
+ TokenValue::Struct.into(),
+ ExpectedToken::TypeName,
+ ],
+ ),
+ meta: token.meta,
+ });
+ }
+ };
+
+ let mut span = token.meta;
+ self.parse_array_specifier(frontend, &mut span, &mut handle)?;
+ Ok((Some(handle), span))
+ }
+
+ pub fn parse_type_non_void(&mut self, frontend: &mut Frontend) -> Result<(Handle<Type>, Span)> {
+ let (maybe_ty, meta) = self.parse_type(frontend)?;
+ let ty = maybe_ty.ok_or_else(|| Error {
+ kind: ErrorKind::SemanticError("Type can't be void".into()),
+ meta,
+ })?;
+
+ Ok((ty, meta))
+ }
+
+ pub fn peek_type_qualifier(&mut self, frontend: &mut Frontend) -> bool {
+ self.peek(frontend).map_or(false, |t| match t.value {
+ TokenValue::Invariant
+ | TokenValue::Interpolation(_)
+ | TokenValue::Sampling(_)
+ | TokenValue::PrecisionQualifier(_)
+ | TokenValue::Const
+ | TokenValue::In
+ | TokenValue::Out
+ | TokenValue::Uniform
+ | TokenValue::Shared
+ | TokenValue::Buffer
+ | TokenValue::Restrict
+ | TokenValue::MemoryQualifier(_)
+ | TokenValue::Layout => true,
+ _ => false,
+ })
+ }
+
+ pub fn parse_type_qualifiers<'a>(
+ &mut self,
+ frontend: &mut Frontend,
+ ) -> Result<TypeQualifiers<'a>> {
+ let mut qualifiers = TypeQualifiers::default();
+
+ while self.peek_type_qualifier(frontend) {
+ let token = self.bump(frontend)?;
+
+ // Handle layout qualifiers outside the match since this can push multiple values
+ if token.value == TokenValue::Layout {
+ self.parse_layout_qualifier_id_list(frontend, &mut qualifiers)?;
+ continue;
+ }
+
+ qualifiers.span.subsume(token.meta);
+
+ match token.value {
+ TokenValue::Invariant => {
+ if qualifiers.invariant.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one invariant qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.invariant = Some(token.meta);
+ }
+ TokenValue::Interpolation(i) => {
+ if qualifiers.interpolation.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one interpolation qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.interpolation = Some((i, token.meta));
+ }
+ TokenValue::Const
+ | TokenValue::In
+ | TokenValue::Out
+ | TokenValue::Uniform
+ | TokenValue::Shared
+ | TokenValue::Buffer => {
+ let storage = match token.value {
+ TokenValue::Const => StorageQualifier::Const,
+ TokenValue::In => StorageQualifier::Input,
+ TokenValue::Out => StorageQualifier::Output,
+ TokenValue::Uniform => {
+ StorageQualifier::AddressSpace(AddressSpace::Uniform)
+ }
+ TokenValue::Shared => {
+ StorageQualifier::AddressSpace(AddressSpace::WorkGroup)
+ }
+ TokenValue::Buffer => {
+ StorageQualifier::AddressSpace(AddressSpace::Storage {
+ access: crate::StorageAccess::all(),
+ })
+ }
+ _ => unreachable!(),
+ };
+
+ if StorageQualifier::AddressSpace(AddressSpace::Function)
+ != qualifiers.storage.0
+ {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one storage qualifier per declaration".into(),
+ ),
+ meta: token.meta,
+ });
+ }
+
+ qualifiers.storage = (storage, token.meta);
+ }
+ TokenValue::Sampling(s) => {
+ if qualifiers.sampling.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one sampling qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.sampling = Some((s, token.meta));
+ }
+ TokenValue::PrecisionQualifier(p) => {
+ if qualifiers.precision.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one precision qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.precision = Some((p, token.meta));
+ }
+ TokenValue::MemoryQualifier(access) => {
+ let storage_access = qualifiers
+ .storage_access
+ .get_or_insert((crate::StorageAccess::all(), Span::default()));
+ if !storage_access.0.contains(!access) {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "The same memory qualifier can only be used once".into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ storage_access.0 &= access;
+ storage_access.1.subsume(token.meta);
+ }
+ TokenValue::Restrict => continue,
+ _ => unreachable!(),
+ };
+ }
+
+ Ok(qualifiers)
+ }
+
+ pub fn parse_layout_qualifier_id_list(
+ &mut self,
+ frontend: &mut Frontend,
+ qualifiers: &mut TypeQualifiers,
+ ) -> Result<()> {
+ self.expect(frontend, TokenValue::LeftParen)?;
+ loop {
+ self.parse_layout_qualifier_id(frontend, &mut qualifiers.layout_qualifiers)?;
+
+ if self.bump_if(frontend, TokenValue::Comma).is_some() {
+ continue;
+ }
+
+ break;
+ }
+ let token = self.expect(frontend, TokenValue::RightParen)?;
+ qualifiers.span.subsume(token.meta);
+
+ Ok(())
+ }
+
+ pub fn parse_layout_qualifier_id(
+ &mut self,
+ frontend: &mut Frontend,
+ qualifiers: &mut crate::FastHashMap<QualifierKey, (QualifierValue, Span)>,
+ ) -> Result<()> {
+ // layout_qualifier_id:
+ // IDENTIFIER
+ // IDENTIFIER EQUAL constant_expression
+ // SHARED
+ let mut token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Identifier(name) => {
+ let (key, value) = match name.as_str() {
+ "std140" => (
+ QualifierKey::Layout,
+ QualifierValue::Layout(StructLayout::Std140),
+ ),
+ "std430" => (
+ QualifierKey::Layout,
+ QualifierValue::Layout(StructLayout::Std430),
+ ),
+ word => {
+ if let Some(format) = map_image_format(word) {
+ (QualifierKey::Format, QualifierValue::Format(format))
+ } else {
+ let key = QualifierKey::String(name.into());
+ let value = if self.bump_if(frontend, TokenValue::Assign).is_some() {
+ let (value, end_meta) = match self.parse_uint_constant(frontend) {
+ Ok(v) => v,
+ Err(e) => {
+ frontend.errors.push(e);
+ (0, Span::default())
+ }
+ };
+ token.meta.subsume(end_meta);
+
+ QualifierValue::Uint(value)
+ } else {
+ QualifierValue::None
+ };
+
+ (key, value)
+ }
+ }
+ };
+
+ qualifiers.insert(key, (value, token.meta));
+ }
+ _ => frontend.errors.push(Error {
+ kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]),
+ meta: token.meta,
+ }),
+ }
+
+ Ok(())
+ }
+
+ pub fn peek_type_name(&mut self, frontend: &mut Frontend) -> bool {
+ self.peek(frontend).map_or(false, |t| match t.value {
+ TokenValue::TypeName(_) | TokenValue::Void => true,
+ TokenValue::Struct => true,
+ TokenValue::Identifier(ref ident) => frontend.lookup_type.contains_key(ident),
+ _ => false,
+ })
+ }
+}
+
+fn map_image_format(word: &str) -> Option<crate::StorageFormat> {
+ use crate::StorageFormat as Sf;
+
+ let format = match word {
+ // float-image-format-qualifier:
+ "rgba32f" => Sf::Rgba32Float,
+ "rgba16f" => Sf::Rgba16Float,
+ "rg32f" => Sf::Rg32Float,
+ "rg16f" => Sf::Rg16Float,
+ "r11f_g11f_b10f" => Sf::Rg11b10Float,
+ "r32f" => Sf::R32Float,
+ "r16f" => Sf::R16Float,
+ "rgba16" => Sf::Rgba16Unorm,
+ "rgb10_a2" => Sf::Rgb10a2Unorm,
+ "rgba8" => Sf::Rgba8Unorm,
+ "rg16" => Sf::Rg16Unorm,
+ "rg8" => Sf::Rg8Unorm,
+ "r16" => Sf::R16Unorm,
+ "r8" => Sf::R8Unorm,
+ "rgba16_snorm" => Sf::Rgba16Snorm,
+ "rgba8_snorm" => Sf::Rgba8Snorm,
+ "rg16_snorm" => Sf::Rg16Snorm,
+ "rg8_snorm" => Sf::Rg8Snorm,
+ "r16_snorm" => Sf::R16Snorm,
+ "r8_snorm" => Sf::R8Snorm,
+ // int-image-format-qualifier:
+ "rgba32i" => Sf::Rgba32Sint,
+ "rgba16i" => Sf::Rgba16Sint,
+ "rgba8i" => Sf::Rgba8Sint,
+ "rg32i" => Sf::Rg32Sint,
+ "rg16i" => Sf::Rg16Sint,
+ "rg8i" => Sf::Rg8Sint,
+ "r32i" => Sf::R32Sint,
+ "r16i" => Sf::R16Sint,
+ "r8i" => Sf::R8Sint,
+ // uint-image-format-qualifier:
+ "rgba32ui" => Sf::Rgba32Uint,
+ "rgba16ui" => Sf::Rgba16Uint,
+ "rgba8ui" => Sf::Rgba8Uint,
+ "rg32ui" => Sf::Rg32Uint,
+ "rg16ui" => Sf::Rg16Uint,
+ "rg8ui" => Sf::Rg8Uint,
+ "r32ui" => Sf::R32Uint,
+ "r16ui" => Sf::R16Uint,
+ "r8ui" => Sf::R8Uint,
+ // TODO: These next ones seem incorrect to me
+ // "rgb10_a2ui" => Sf::Rgb10a2Unorm,
+ _ => return None,
+ };
+
+ Some(format)
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs
new file mode 100644
index 0000000000..ae529a6fbe
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs
@@ -0,0 +1,832 @@
+use super::{
+ ast::Profile,
+ error::ExpectedToken,
+ error::{Error, ErrorKind},
+ token::TokenValue,
+ Frontend, Options, Span,
+};
+use crate::ShaderStage;
+use pp_rs::token::PreprocessorError;
+
+#[test]
+fn version() {
+ let mut frontend = Frontend::default();
+
+ // invalid versions
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 99000\n void main(){}",
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::InvalidVersion(99000),
+ meta: Span::new(9, 14)
+ }],
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 449\n void main(){}",
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::InvalidVersion(449),
+ meta: Span::new(9, 12)
+ }]
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450 smart\n void main(){}",
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::InvalidProfile("smart".into()),
+ meta: Span::new(13, 18),
+ }]
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450\nvoid main(){} #version 450",
+ )
+ .err()
+ .unwrap(),
+ vec![
+ Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,),
+ meta: Span::new(27, 28),
+ },
+ Error {
+ kind: ErrorKind::InvalidToken(
+ TokenValue::Identifier("version".into()),
+ vec![ExpectedToken::Eof]
+ ),
+ meta: Span::new(28, 35)
+ }
+ ]
+ );
+
+ // valid versions
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ " # version 450\nvoid main() {}",
+ )
+ .unwrap();
+ assert_eq!(
+ (frontend.metadata().version, frontend.metadata().profile),
+ (450, Profile::Core)
+ );
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450\nvoid main() {}",
+ )
+ .unwrap();
+ assert_eq!(
+ (frontend.metadata().version, frontend.metadata().profile),
+ (450, Profile::Core)
+ );
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450 core\nvoid main(void) {}",
+ )
+ .unwrap();
+ assert_eq!(
+ (frontend.metadata().version, frontend.metadata().profile),
+ (450, Profile::Core)
+ );
+}
+
+#[test]
+fn control_flow() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ if (true) {
+ return 1;
+ } else {
+ return 2;
+ }
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ if (true) {
+ return 1;
+ }
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ int x;
+ int y = 3;
+ switch (5) {
+ case 2:
+ x = 2;
+ case 5:
+ x = 5;
+ y = 2;
+ break;
+ default:
+ x = 0;
+ }
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ int x = 0;
+ while(x < 5) {
+ x = x + 1;
+ }
+ do {
+ x = x - 1;
+ } while(x >= 4)
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ int x = 0;
+ for(int i = 0; i < 10;) {
+ x = x + 2;
+ }
+ for(;;);
+ return x;
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn declarations() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(location = 0) in vec2 v_uv;
+ layout(location = 0) out vec4 o_color;
+ layout(set = 1, binding = 1) uniform texture2D tex;
+ layout(set = 1, binding = 2) uniform sampler tex_sampler;
+
+ layout(early_fragment_tests) in;
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(std140, set = 2, binding = 0)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ ivec4 atlas_offs;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(push_constant)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ ivec4 atlas_offs;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(std430, set = 2, binding = 0)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ ivec4 atlas_offs;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(std140, set = 2, binding = 0)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ } block_var;
+
+ void main() {
+ load_time * model_offs;
+ block_var.load_time * block_var.model_offs;
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ float vector = vec4(1.0 / 17.0, 9.0 / 17.0, 3.0 / 17.0, 11.0 / 17.0);
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ precision highp float;
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn textures() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(location = 0) in vec2 v_uv;
+ layout(location = 0) out vec4 o_color;
+ layout(set = 1, binding = 1) uniform texture2D tex;
+ layout(set = 1, binding = 2) uniform sampler tex_sampler;
+ void main() {
+ o_color = texture(sampler2D(tex, tex_sampler), v_uv);
+ o_color.a = texture(sampler2D(tex, tex_sampler), v_uv, 2.0).a;
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn functions() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test1(float);
+ void test1(float) {}
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test2(float a) {}
+ void test3(float a, float b) {}
+ void test4(float, float) {}
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(float a) { return a; }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(vec4 p) {
+ return p.x;
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ // Function overloading
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(vec2 p);
+ float test(vec3 p);
+ float test(vec4 p);
+
+ float test(vec2 p) {
+ return p.x;
+ }
+
+ float test(vec3 p) {
+ return p.x;
+ }
+
+ float test(vec4 p) {
+ return p.x;
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ int test(vec4 p) {
+ return p.x;
+ }
+
+ float test(vec4 p) {
+ return p.x;
+ }
+
+ void main() {}
+ "#,
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::SemanticError("Function already defined".into()),
+ meta: Span::new(134, 152),
+ }]
+ );
+
+ println!();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float callee(uint q) {
+ return float(q);
+ }
+
+ float caller() {
+ callee(1u);
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ // Nested function call
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ layout(set = 0, binding = 1) uniform texture2D t_noise;
+ layout(set = 0, binding = 2) uniform sampler s_noise;
+
+ void main() {
+ textureLod(sampler2D(t_noise, s_noise), vec2(1.0), 0);
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void fun(vec2 in_parameter, out float out_parameter) {
+ ivec2 _ = ivec2(in_parameter);
+ }
+
+ void main() {
+ float a;
+ fun(vec2(1.0), a);
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn constants() {
+ use crate::{Constant, ConstantInner, ScalarValue};
+ let mut frontend = Frontend::default();
+
+ let module = frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ const float a = 1.0;
+ float global = a;
+ const float b = a;
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ let mut constants = module.constants.iter();
+
+ assert_eq!(
+ constants.next().unwrap().1,
+ &Constant {
+ name: Some("a".to_owned()),
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(1.0)
+ }
+ }
+ );
+ assert_eq!(
+ constants.next().unwrap().1,
+ &Constant {
+ name: Some("b".to_owned()),
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(1.0)
+ }
+ }
+ );
+
+ assert!(constants.next().is_none());
+}
+
+#[test]
+fn function_overloading() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+
+ float saturate(float v) { return clamp(v, 0.0, 1.0); }
+ vec2 saturate(vec2 v) { return clamp(v, vec2(0.0), vec2(1.0)); }
+ vec3 saturate(vec3 v) { return clamp(v, vec3(0.0), vec3(1.0)); }
+ vec4 saturate(vec4 v) { return clamp(v, vec4(0.0), vec4(1.0)); }
+
+ void main() {
+ float v1 = saturate(1.5);
+ vec2 v2 = saturate(vec2(0.5, 1.5));
+ vec3 v3 = saturate(vec3(0.5, 1.5, 2.5));
+ vec3 v4 = saturate(vec4(0.5, 1.5, 2.5, 3.5));
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn implicit_conversions() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ mat4 a = mat4(1);
+ float b = 1u;
+ float c = 1 + 2.0;
+ }
+ "#,
+ )
+ .unwrap();
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test(int a) {}
+ void test(uint a) {}
+
+ void main() {
+ test(1.0);
+ }
+ "#,
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::SemanticError("Unknown function \'test\'".into()),
+ meta: Span::new(156, 165),
+ }]
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test(float a) {}
+ void test(uint a) {}
+
+ void main() {
+ test(1);
+ }
+ "#,
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()),
+ meta: Span::new(158, 165),
+ }]
+ );
+}
+
+#[test]
+fn structs() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ Test {
+ vec4 pos;
+ } xx;
+
+ void main() {}
+ "#,
+ )
+ .unwrap_err();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ struct Test {
+ vec4 pos;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ const int NUM_VECS = 42;
+ struct Test {
+ vec4 vecs[NUM_VECS];
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ struct Hello {
+ vec4 test;
+ } test() {
+ return Hello( vec4(1.0) );
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ struct Test {};
+
+ void main() {}
+ "#,
+ )
+ .unwrap_err();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ inout struct Test {
+ vec4 x;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap_err();
+}
+
+#[test]
+fn swizzles() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ vec4 v = vec4(1);
+ v.xyz = vec3(2);
+ v.x = 5.0;
+ v.xyz.zxy.yx.xy = vec2(5.0, 1.0);
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ vec4 v = vec4(1);
+ v.xx = vec2(5.0);
+ }
+ "#,
+ )
+ .unwrap_err();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ vec3 v = vec3(1);
+ v.w = 2.0;
+ }
+ "#,
+ )
+ .unwrap_err();
+}
+
+#[test]
+fn expressions() {
+ let mut frontend = Frontend::default();
+
+ // Vector indexing
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(int index) {
+ vec4 v = vec4(1.0, 2.0, 3.0, 4.0);
+ return v[index] + 1.0;
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ // Prefix increment/decrement
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ uint index = 0;
+
+ --index;
+ ++index;
+ }
+ "#,
+ )
+ .unwrap();
+
+ // Dynamic indexing of array
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ const vec4 positions[1] = { vec4(0) };
+
+ gl_Position = positions[gl_VertexIndex];
+ }
+ "#,
+ )
+ .unwrap();
+}
diff --git a/third_party/rust/naga/src/front/glsl/token.rs b/third_party/rust/naga/src/front/glsl/token.rs
new file mode 100644
index 0000000000..74f10cbf85
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/token.rs
@@ -0,0 +1,137 @@
+pub use pp_rs::token::{Float, Integer, Location, PreprocessorError, Token as PPToken};
+
+use super::ast::Precision;
+use crate::{Interpolation, Sampling, Span, Type};
+
+impl From<Location> for Span {
+ fn from(loc: Location) -> Self {
+ Span::new(loc.start, loc.end)
+ }
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Token {
+ pub value: TokenValue,
+ pub meta: Span,
+}
+
+/// A token passed from the lexing used in the parsing.
+///
+/// This type is exported since it's returned in the
+/// [`InvalidToken`](super::ErrorKind::InvalidToken) error.
+#[derive(Debug, PartialEq)]
+pub enum TokenValue {
+ Identifier(String),
+
+ FloatConstant(Float),
+ IntConstant(Integer),
+ BoolConstant(bool),
+
+ Layout,
+ In,
+ Out,
+ InOut,
+ Uniform,
+ Buffer,
+ Const,
+ Shared,
+
+ Restrict,
+ /// A `glsl` memory qualifier such as `writeonly`
+ ///
+ /// The associated [`crate::StorageAccess`] is the access being allowed
+ /// (for example `writeonly` has an associated value of [`crate::StorageAccess::STORE`])
+ MemoryQualifier(crate::StorageAccess),
+
+ Invariant,
+ Interpolation(Interpolation),
+ Sampling(Sampling),
+ Precision,
+ PrecisionQualifier(Precision),
+
+ Continue,
+ Break,
+ Return,
+ Discard,
+
+ If,
+ Else,
+ Switch,
+ Case,
+ Default,
+ While,
+ Do,
+ For,
+
+ Void,
+ Struct,
+ TypeName(Type),
+
+ Assign,
+ AddAssign,
+ SubAssign,
+ MulAssign,
+ DivAssign,
+ ModAssign,
+ LeftShiftAssign,
+ RightShiftAssign,
+ AndAssign,
+ XorAssign,
+ OrAssign,
+
+ Increment,
+ Decrement,
+
+ LogicalOr,
+ LogicalAnd,
+ LogicalXor,
+
+ LessEqual,
+ GreaterEqual,
+ Equal,
+ NotEqual,
+
+ LeftShift,
+ RightShift,
+
+ LeftBrace,
+ RightBrace,
+ LeftParen,
+ RightParen,
+ LeftBracket,
+ RightBracket,
+ LeftAngle,
+ RightAngle,
+
+ Comma,
+ Semicolon,
+ Colon,
+ Dot,
+ Bang,
+ Dash,
+ Tilde,
+ Plus,
+ Star,
+ Slash,
+ Percent,
+ VerticalBar,
+ Caret,
+ Ampersand,
+ Question,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Directive {
+ pub kind: DirectiveKind,
+ pub tokens: Vec<PPToken>,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum DirectiveKind {
+ Version { is_first_directive: bool },
+ Extension,
+ Pragma,
+}
diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs
new file mode 100644
index 0000000000..a7967848d5
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/types.rs
@@ -0,0 +1,335 @@
+use super::{
+ constants::ConstantSolver, context::Context, Error, ErrorKind, Frontend, Result, Span,
+};
+use crate::{
+ proc::ResolveContext, Bytes, Constant, Expression, Handle, ImageClass, ImageDimension,
+ ScalarKind, Type, TypeInner, VectorSize,
+};
+
+pub fn parse_type(type_name: &str) -> Option<Type> {
+ match type_name {
+ "bool" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ },
+ }),
+ "float" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ }),
+ "double" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 8,
+ },
+ }),
+ "int" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Sint,
+ width: 4,
+ },
+ }),
+ "uint" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Uint,
+ width: 4,
+ },
+ }),
+ "sampler" | "samplerShadow" => Some(Type {
+ name: None,
+ inner: TypeInner::Sampler {
+ comparison: type_name == "samplerShadow",
+ },
+ }),
+ word => {
+ fn kind_width_parse(ty: &str) -> Option<(ScalarKind, u8)> {
+ Some(match ty {
+ "" => (ScalarKind::Float, 4),
+ "b" => (ScalarKind::Bool, crate::BOOL_WIDTH),
+ "i" => (ScalarKind::Sint, 4),
+ "u" => (ScalarKind::Uint, 4),
+ "d" => (ScalarKind::Float, 8),
+ _ => return None,
+ })
+ }
+
+ fn size_parse(n: &str) -> Option<VectorSize> {
+ Some(match n {
+ "2" => VectorSize::Bi,
+ "3" => VectorSize::Tri,
+ "4" => VectorSize::Quad,
+ _ => return None,
+ })
+ }
+
+ let vec_parse = |word: &str| {
+ let mut iter = word.split("vec");
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let (kind, width) = kind_width_parse(kind)?;
+ let size = size_parse(size)?;
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Vector { size, kind, width },
+ })
+ };
+
+ let mat_parse = |word: &str| {
+ let mut iter = word.split("mat");
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let (_, width) = kind_width_parse(kind)?;
+
+ let (columns, rows) = if let Some(size) = size_parse(size) {
+ (size, size)
+ } else {
+ let mut iter = size.split('x');
+ match (iter.next()?, iter.next()?, iter.next()) {
+ (col, row, None) => (size_parse(col)?, size_parse(row)?),
+ _ => return None,
+ }
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ })
+ };
+
+ let texture_parse = |word: &str| {
+ let mut iter = word.split("texture");
+
+ let texture_kind = |ty| {
+ Some(match ty {
+ "" => ScalarKind::Float,
+ "i" => ScalarKind::Sint,
+ "u" => ScalarKind::Uint,
+ _ => return None,
+ })
+ };
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let kind = texture_kind(kind)?;
+
+ let sampled = |multi| ImageClass::Sampled { kind, multi };
+
+ let (dim, arrayed, class) = match size {
+ "1D" => (ImageDimension::D1, false, sampled(false)),
+ "1DArray" => (ImageDimension::D1, true, sampled(false)),
+ "2D" => (ImageDimension::D2, false, sampled(false)),
+ "2DArray" => (ImageDimension::D2, true, sampled(false)),
+ "2DMS" => (ImageDimension::D2, false, sampled(true)),
+ "2DMSArray" => (ImageDimension::D2, true, sampled(true)),
+ "3D" => (ImageDimension::D3, false, sampled(false)),
+ "Cube" => (ImageDimension::Cube, false, sampled(false)),
+ "CubeArray" => (ImageDimension::Cube, true, sampled(false)),
+ _ => return None,
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ })
+ };
+
+ let image_parse = |word: &str| {
+ let mut iter = word.split("image");
+
+ let texture_kind = |ty| {
+ Some(match ty {
+ "" => ScalarKind::Float,
+ "i" => ScalarKind::Sint,
+ "u" => ScalarKind::Uint,
+ _ => return None,
+ })
+ };
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ // TODO: Check that the texture format and the kind match
+ let _ = texture_kind(kind)?;
+
+ let class = ImageClass::Storage {
+ format: crate::StorageFormat::R8Uint,
+ access: crate::StorageAccess::all(),
+ };
+
+ // TODO: glsl support multisampled storage images, naga doesn't
+ let (dim, arrayed) = match size {
+ "1D" => (ImageDimension::D1, false),
+ "1DArray" => (ImageDimension::D1, true),
+ "2D" => (ImageDimension::D2, false),
+ "2DArray" => (ImageDimension::D2, true),
+ "3D" => (ImageDimension::D3, false),
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ // "Cube" => (ImageDimension::Cube, false),
+ // "CubeArray" => (ImageDimension::Cube, true),
+ _ => return None,
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ })
+ };
+
+ vec_parse(word)
+ .or_else(|| mat_parse(word))
+ .or_else(|| texture_parse(word))
+ .or_else(|| image_parse(word))
+ }
+ }
+}
+
+pub const fn scalar_components(ty: &TypeInner) -> Option<(ScalarKind, Bytes)> {
+ match *ty {
+ TypeInner::Scalar { kind, width } => Some((kind, width)),
+ TypeInner::Vector { kind, width, .. } => Some((kind, width)),
+ TypeInner::Matrix { width, .. } => Some((ScalarKind::Float, width)),
+ TypeInner::ValuePointer { kind, width, .. } => Some((kind, width)),
+ _ => None,
+ }
+}
+
+pub const fn type_power(kind: ScalarKind, width: Bytes) -> Option<u32> {
+ Some(match kind {
+ ScalarKind::Sint => 0,
+ ScalarKind::Uint => 1,
+ ScalarKind::Float if width == 4 => 2,
+ ScalarKind::Float => 3,
+ ScalarKind::Bool => return None,
+ })
+}
+
+impl Frontend {
+ /// Resolves the types of the expressions until `expr` (inclusive)
+ ///
+ /// This needs to be done before the [`typifier`] can be queried for
+ /// the types of the expressions in the range between the last grow and `expr`.
+ ///
+ /// # Note
+ ///
+ /// The `resolve_type*` methods (like [`resolve_type`]) automatically
+ /// grow the [`typifier`] so calling this method is not necessary when using
+ /// them.
+ ///
+ /// [`typifier`]: Context::typifier
+ /// [`resolve_type`]: Self::resolve_type
+ pub(crate) fn typifier_grow(
+ &self,
+ ctx: &mut Context,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<()> {
+ let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments);
+
+ ctx.typifier
+ .grow(expr, &ctx.expressions, &resolve_ctx)
+ .map_err(|error| Error {
+ kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()),
+ meta,
+ })
+ }
+
+ /// Gets the type for the result of the `expr` expression
+ ///
+ /// Automatically grows the [`typifier`] to `expr` so calling
+ /// [`typifier_grow`] is not necessary
+ ///
+ /// [`typifier`]: Context::typifier
+ /// [`typifier_grow`]: Self::typifier_grow
+ pub(crate) fn resolve_type<'b>(
+ &'b self,
+ ctx: &'b mut Context,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<&'b TypeInner> {
+ self.typifier_grow(ctx, expr, meta)?;
+ Ok(ctx.typifier.get(expr, &self.module.types))
+ }
+
+ /// Gets the type handle for the result of the `expr` expression
+ ///
+ /// Automatically grows the [`typifier`] to `expr` so calling
+ /// [`typifier_grow`] is not necessary
+ ///
+ /// # Note
+ ///
+ /// Consider using [`resolve_type`] whenever possible
+ /// since it doesn't require adding each type to the [`types`] arena
+ /// and it doesn't need to mutably borrow the [`Parser`][Self]
+ ///
+ /// [`types`]: crate::Module::types
+ /// [`typifier`]: Context::typifier
+ /// [`typifier_grow`]: Self::typifier_grow
+ /// [`resolve_type`]: Self::resolve_type
+ pub(crate) fn resolve_type_handle(
+ &mut self,
+ ctx: &mut Context,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Handle<Type>> {
+ self.typifier_grow(ctx, expr, meta)?;
+ Ok(ctx.typifier.register_type(expr, &mut self.module.types))
+ }
+
+ /// Invalidates the cached type resolution for `expr` forcing a recomputation
+ pub(crate) fn invalidate_expression<'b>(
+ &'b self,
+ ctx: &'b mut Context,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<()> {
+ let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments);
+
+ ctx.typifier
+ .invalidate(expr, &ctx.expressions, &resolve_ctx)
+ .map_err(|error| Error {
+ kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()),
+ meta,
+ })
+ }
+
+ pub(crate) fn solve_constant(
+ &mut self,
+ ctx: &Context,
+ root: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Handle<Constant>> {
+ let mut solver = ConstantSolver {
+ types: &mut self.module.types,
+ expressions: &ctx.expressions,
+ constants: &mut self.module.constants,
+ };
+
+ solver.solve(root).map_err(|e| Error {
+ kind: e.into(),
+ meta,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs
new file mode 100644
index 0000000000..4c0e0a927d
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/variables.rs
@@ -0,0 +1,679 @@
+use super::{
+ ast::*,
+ context::{Context, ExprPos},
+ error::{Error, ErrorKind},
+ Frontend, Result, Span,
+};
+use crate::{
+ AddressSpace, Binding, Block, BuiltIn, Constant, Expression, GlobalVariable, Handle,
+ Interpolation, LocalVariable, ResourceBinding, ScalarKind, ShaderStage, SwizzleComponent, Type,
+ TypeInner, VectorSize,
+};
+
+pub struct VarDeclaration<'a, 'key> {
+ pub qualifiers: &'a mut TypeQualifiers<'key>,
+ pub ty: Handle<Type>,
+ pub name: Option<String>,
+ pub init: Option<Handle<Constant>>,
+ pub meta: Span,
+}
+
+/// Information about a builtin used in [`add_builtin`](Frontend::add_builtin).
+struct BuiltInData {
+ /// The type of the builtin.
+ inner: TypeInner,
+ /// The associated builtin class.
+ builtin: BuiltIn,
+ /// Whether the builtin can be written to or not.
+ mutable: bool,
+ /// The storage used for the builtin.
+ storage: StorageQualifier,
+}
+
+pub enum GlobalOrConstant {
+ Global(Handle<GlobalVariable>),
+ Constant(Handle<Constant>),
+}
+
+impl Frontend {
+ /// Adds a builtin and returns a variable reference to it
+ fn add_builtin(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ name: &str,
+ data: BuiltInData,
+ meta: Span,
+ ) -> Option<VariableReference> {
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: data.inner,
+ },
+ meta,
+ );
+
+ let handle = self.module.global_variables.append(
+ GlobalVariable {
+ name: Some(name.into()),
+ space: AddressSpace::Private,
+ binding: None,
+ ty,
+ init: None,
+ },
+ meta,
+ );
+
+ let idx = self.entry_args.len();
+ self.entry_args.push(EntryArg {
+ name: None,
+ binding: Binding::BuiltIn(data.builtin),
+ handle,
+ storage: data.storage,
+ });
+
+ self.global_variables.push((
+ name.into(),
+ GlobalLookup {
+ kind: GlobalLookupKind::Variable(handle),
+ entry_arg: Some(idx),
+ mutable: data.mutable,
+ },
+ ));
+
+ let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta, body);
+
+ let var = VariableReference {
+ expr,
+ load: true,
+ mutable: data.mutable,
+ constant: None,
+ entry_arg: Some(idx),
+ };
+
+ ctx.symbol_table.add_root(name.into(), var.clone());
+
+ Some(var)
+ }
+
+ pub(crate) fn lookup_variable(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ name: &str,
+ meta: Span,
+ ) -> Option<VariableReference> {
+ if let Some(var) = ctx.symbol_table.lookup(name).cloned() {
+ return Some(var);
+ }
+
+ let data = match name {
+ "gl_Position" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Quad,
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ builtin: BuiltIn::Position { invariant: false },
+ mutable: true,
+ storage: StorageQualifier::Output,
+ },
+ "gl_FragCoord" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Quad,
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ builtin: BuiltIn::Position { invariant: false },
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_PointCoord" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Bi,
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ builtin: BuiltIn::PointCoord,
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_GlobalInvocationID"
+ | "gl_NumWorkGroups"
+ | "gl_WorkGroupSize"
+ | "gl_WorkGroupID"
+ | "gl_LocalInvocationID" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Tri,
+ kind: ScalarKind::Uint,
+ width: 4,
+ },
+ builtin: match name {
+ "gl_GlobalInvocationID" => BuiltIn::GlobalInvocationId,
+ "gl_NumWorkGroups" => BuiltIn::NumWorkGroups,
+ "gl_WorkGroupSize" => BuiltIn::WorkGroupSize,
+ "gl_WorkGroupID" => BuiltIn::WorkGroupId,
+ "gl_LocalInvocationID" => BuiltIn::LocalInvocationId,
+ _ => unreachable!(),
+ },
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_FrontFacing" => BuiltInData {
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ },
+ builtin: BuiltIn::FrontFacing,
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_PointSize" | "gl_FragDepth" => BuiltInData {
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ builtin: match name {
+ "gl_PointSize" => BuiltIn::PointSize,
+ "gl_FragDepth" => BuiltIn::FragDepth,
+ _ => unreachable!(),
+ },
+ mutable: true,
+ storage: StorageQualifier::Output,
+ },
+ "gl_ClipDistance" | "gl_CullDistance" => {
+ let base = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ },
+ meta,
+ );
+
+ BuiltInData {
+ inner: TypeInner::Array {
+ base,
+ size: crate::ArraySize::Dynamic,
+ stride: 4,
+ },
+ builtin: match name {
+ "gl_ClipDistance" => BuiltIn::ClipDistance,
+ "gl_CullDistance" => BuiltIn::CullDistance,
+ _ => unreachable!(),
+ },
+ mutable: self.meta.stage == ShaderStage::Vertex,
+ storage: StorageQualifier::Output,
+ }
+ }
+ _ => {
+ let builtin = match name {
+ "gl_BaseVertex" => BuiltIn::BaseVertex,
+ "gl_BaseInstance" => BuiltIn::BaseInstance,
+ "gl_PrimitiveID" => BuiltIn::PrimitiveIndex,
+ "gl_InstanceIndex" => BuiltIn::InstanceIndex,
+ "gl_VertexIndex" => BuiltIn::VertexIndex,
+ "gl_SampleID" => BuiltIn::SampleIndex,
+ "gl_LocalInvocationIndex" => BuiltIn::LocalInvocationIndex,
+ _ => return None,
+ };
+
+ BuiltInData {
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Uint,
+ width: 4,
+ },
+ builtin,
+ mutable: false,
+ storage: StorageQualifier::Input,
+ }
+ }
+ };
+
+ self.add_builtin(ctx, body, name, data, meta)
+ }
+
+ pub(crate) fn make_variable_invariant(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ name: &str,
+ meta: Span,
+ ) {
+ if let Some(var) = self.lookup_variable(ctx, body, name, meta) {
+ if let Some(index) = var.entry_arg {
+ if let Binding::BuiltIn(BuiltIn::Position { ref mut invariant }) =
+ self.entry_args[index].binding
+ {
+ *invariant = true;
+ }
+ }
+ }
+ }
+
+ pub(crate) fn field_selection(
+ &mut self,
+ ctx: &mut Context,
+ pos: ExprPos,
+ body: &mut Block,
+ expression: Handle<Expression>,
+ name: &str,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let (ty, is_pointer) = match *self.resolve_type(ctx, expression, meta)? {
+ TypeInner::Pointer { base, .. } => (&self.module.types[base].inner, true),
+ ref ty => (ty, false),
+ };
+ match *ty {
+ TypeInner::Struct { ref members, .. } => {
+ let index = members
+ .iter()
+ .position(|m| m.name == Some(name.into()))
+ .ok_or_else(|| Error {
+ kind: ErrorKind::UnknownField(name.into()),
+ meta,
+ })?;
+ let pointer = ctx.add_expression(
+ Expression::AccessIndex {
+ base: expression,
+ index: index as u32,
+ },
+ meta,
+ body,
+ );
+
+ Ok(match pos {
+ ExprPos::Rhs if is_pointer => {
+ ctx.add_expression(Expression::Load { pointer }, meta, body)
+ }
+ _ => pointer,
+ })
+ }
+ // swizzles (xyzw, rgba, stpq)
+ TypeInner::Vector { size, .. } => {
+ let check_swizzle_components = |comps: &str| {
+ name.chars()
+ .map(|c| {
+ comps
+ .find(c)
+ .filter(|i| *i < size as usize)
+ .map(|i| SwizzleComponent::from_index(i as u32))
+ })
+ .collect::<Option<Vec<SwizzleComponent>>>()
+ };
+
+ let components = check_swizzle_components("xyzw")
+ .or_else(|| check_swizzle_components("rgba"))
+ .or_else(|| check_swizzle_components("stpq"));
+
+ if let Some(components) = components {
+ if let ExprPos::Lhs = pos {
+ let not_unique = (1..components.len())
+ .any(|i| components[i..].contains(&components[i - 1]));
+ if not_unique {
+ self.errors.push(Error {
+ kind:
+ ErrorKind::SemanticError(
+ format!(
+ "swizzle cannot have duplicate components in left-hand-side expression for \"{name:?}\""
+ )
+ .into(),
+ ),
+ meta ,
+ })
+ }
+ }
+
+ let mut pattern = [SwizzleComponent::X; 4];
+ for (pat, component) in pattern.iter_mut().zip(&components) {
+ *pat = *component;
+ }
+
+ // flatten nested swizzles (vec.zyx.xy.x => vec.z)
+ let mut expression = expression;
+ if let Expression::Swizzle {
+ size: _,
+ vector,
+ pattern: ref src_pattern,
+ } = ctx[expression]
+ {
+ expression = vector;
+ for pat in &mut pattern {
+ *pat = src_pattern[pat.index() as usize];
+ }
+ }
+
+ let size = match components.len() {
+ // Swizzles with just one component are accesses and not swizzles
+ 1 => {
+ match pos {
+ // If the position is in the right hand side and the base
+ // vector is a pointer, load it, otherwise the swizzle would
+ // produce a pointer
+ ExprPos::Rhs if is_pointer => {
+ expression = ctx.add_expression(
+ Expression::Load {
+ pointer: expression,
+ },
+ meta,
+ body,
+ );
+ }
+ _ => {}
+ };
+ return Ok(ctx.add_expression(
+ Expression::AccessIndex {
+ base: expression,
+ index: pattern[0].index(),
+ },
+ meta,
+ body,
+ ));
+ }
+ 2 => VectorSize::Bi,
+ 3 => VectorSize::Tri,
+ 4 => VectorSize::Quad,
+ _ => {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Bad swizzle size for \"{name:?}\"").into(),
+ ),
+ meta,
+ });
+
+ VectorSize::Quad
+ }
+ };
+
+ if is_pointer {
+ // NOTE: for lhs expression, this extra load ends up as an unused expr, because the
+ // assignment will extract the pointer and use it directly anyway. Unfortunately we
+ // need it for validation to pass, as swizzles cannot operate on pointer values.
+ expression = ctx.add_expression(
+ Expression::Load {
+ pointer: expression,
+ },
+ meta,
+ body,
+ );
+ }
+
+ Ok(ctx.add_expression(
+ Expression::Swizzle {
+ size,
+ vector: expression,
+ pattern,
+ },
+ meta,
+ body,
+ ))
+ } else {
+ Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Invalid swizzle for vector \"{name}\"").into(),
+ ),
+ meta,
+ })
+ }
+ }
+ _ => Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Can't lookup field on this type \"{name}\"").into(),
+ ),
+ meta,
+ }),
+ }
+ }
+
+ pub(crate) fn add_global_var(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ VarDeclaration {
+ qualifiers,
+ mut ty,
+ name,
+ init,
+ meta,
+ }: VarDeclaration,
+ ) -> Result<GlobalOrConstant> {
+ let storage = qualifiers.storage.0;
+ let (ret, lookup) = match storage {
+ StorageQualifier::Input | StorageQualifier::Output => {
+ let input = storage == StorageQualifier::Input;
+ // TODO: glslang seems to use a counter for variables without
+ // explicit location (even if that causes collisions)
+ let location = qualifiers
+ .uint_layout_qualifier("location", &mut self.errors)
+ .unwrap_or(0);
+ let interpolation = qualifiers.interpolation.take().map(|(i, _)| i).or_else(|| {
+ let kind = self.module.types[ty].inner.scalar_kind()?;
+ Some(match kind {
+ ScalarKind::Float => Interpolation::Perspective,
+ _ => Interpolation::Flat,
+ })
+ });
+ let sampling = qualifiers.sampling.take().map(|(s, _)| s);
+
+ let handle = self.module.global_variables.append(
+ GlobalVariable {
+ name: name.clone(),
+ space: AddressSpace::Private,
+ binding: None,
+ ty,
+ init,
+ },
+ meta,
+ );
+
+ let idx = self.entry_args.len();
+ self.entry_args.push(EntryArg {
+ name: name.clone(),
+ binding: Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ },
+ handle,
+ storage,
+ });
+
+ let lookup = GlobalLookup {
+ kind: GlobalLookupKind::Variable(handle),
+ entry_arg: Some(idx),
+ mutable: !input,
+ };
+
+ (GlobalOrConstant::Global(handle), lookup)
+ }
+ StorageQualifier::Const => {
+ let init = init.ok_or_else(|| Error {
+ kind: ErrorKind::SemanticError("const values must have an initializer".into()),
+ meta,
+ })?;
+
+ let lookup = GlobalLookup {
+ kind: GlobalLookupKind::Constant(init, ty),
+ entry_arg: None,
+ mutable: false,
+ };
+
+ if let Some(name) = name.as_ref() {
+ let constant = self.module.constants.get_mut(init);
+ if constant.name.is_none() {
+ // set the name of the constant
+ constant.name = Some(name.clone())
+ } else {
+ // add a copy of the constant with the new name
+ let new_const = Constant {
+ name: Some(name.clone()),
+ specialization: constant.specialization,
+ inner: constant.inner.clone(),
+ };
+ self.module.constants.fetch_or_append(new_const, meta);
+ }
+ }
+
+ (GlobalOrConstant::Constant(init), lookup)
+ }
+ StorageQualifier::AddressSpace(mut space) => {
+ match space {
+ AddressSpace::Storage { ref mut access } => {
+ if let Some((allowed_access, _)) = qualifiers.storage_access.take() {
+ *access = allowed_access;
+ }
+ }
+ AddressSpace::Uniform => match self.module.types[ty].inner {
+ TypeInner::Image {
+ class,
+ dim,
+ arrayed,
+ } => {
+ if let crate::ImageClass::Storage {
+ mut access,
+ mut format,
+ } = class
+ {
+ if let Some((allowed_access, _)) = qualifiers.storage_access.take()
+ {
+ access = allowed_access;
+ }
+
+ match qualifiers.layout_qualifiers.remove(&QualifierKey::Format) {
+ Some((QualifierValue::Format(f), _)) => format = f,
+ // TODO: glsl supports images without format qualifier
+ // if they are `writeonly`
+ None => self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "image types require a format layout qualifier".into(),
+ ),
+ meta,
+ }),
+ _ => unreachable!(),
+ }
+
+ ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Storage { format, access },
+ },
+ },
+ meta,
+ );
+ }
+
+ space = AddressSpace::Handle
+ }
+ TypeInner::Sampler { .. } => space = AddressSpace::Handle,
+ _ => {
+ if qualifiers.none_layout_qualifier("push_constant", &mut self.errors) {
+ space = AddressSpace::PushConstant
+ }
+ }
+ },
+ AddressSpace::Function => space = AddressSpace::Private,
+ _ => {}
+ };
+
+ let binding = match space {
+ AddressSpace::Uniform | AddressSpace::Storage { .. } | AddressSpace::Handle => {
+ let binding = qualifiers.uint_layout_qualifier("binding", &mut self.errors);
+ if binding.is_none() {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "uniform/buffer blocks require layout(binding=X)".into(),
+ ),
+ meta,
+ });
+ }
+ let set = qualifiers.uint_layout_qualifier("set", &mut self.errors);
+ binding.map(|binding| ResourceBinding {
+ group: set.unwrap_or(0),
+ binding,
+ })
+ }
+ _ => None,
+ };
+
+ let handle = self.module.global_variables.append(
+ GlobalVariable {
+ name: name.clone(),
+ space,
+ binding,
+ ty,
+ init,
+ },
+ meta,
+ );
+
+ let lookup = GlobalLookup {
+ kind: GlobalLookupKind::Variable(handle),
+ entry_arg: None,
+ mutable: true,
+ };
+
+ (GlobalOrConstant::Global(handle), lookup)
+ }
+ };
+
+ if let Some(name) = name {
+ ctx.add_global(self, &name, lookup, body);
+
+ self.global_variables.push((name, lookup));
+ }
+
+ qualifiers.unused_errors(&mut self.errors);
+
+ Ok(ret)
+ }
+
+ pub(crate) fn add_local_var(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ decl: VarDeclaration,
+ ) -> Result<Handle<Expression>> {
+ let storage = decl.qualifiers.storage;
+ let mutable = match storage.0 {
+ StorageQualifier::AddressSpace(AddressSpace::Function) => true,
+ StorageQualifier::Const => false,
+ _ => {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Locals cannot have a storage qualifier".into()),
+ meta: storage.1,
+ });
+ true
+ }
+ };
+
+ let handle = ctx.locals.append(
+ LocalVariable {
+ name: decl.name.clone(),
+ ty: decl.ty,
+ init: None,
+ },
+ decl.meta,
+ );
+ let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta, body);
+
+ if let Some(name) = decl.name {
+ let maybe_var = ctx.add_local_var(name.clone(), expr, mutable);
+
+ if maybe_var.is_some() {
+ self.errors.push(Error {
+ kind: ErrorKind::VariableAlreadyDeclared(name),
+ meta: decl.meta,
+ })
+ }
+ }
+
+ decl.qualifiers.unused_errors(&mut self.errors);
+
+ Ok(expr)
+ }
+}
diff --git a/third_party/rust/naga/src/front/interpolator.rs b/third_party/rust/naga/src/front/interpolator.rs
new file mode 100644
index 0000000000..96f5db6ff7
--- /dev/null
+++ b/third_party/rust/naga/src/front/interpolator.rs
@@ -0,0 +1,61 @@
+/*!
+Interpolation defaults.
+*/
+
+impl crate::Binding {
+ /// Apply the usual default interpolation for `ty` to `binding`.
+ ///
+ /// This function is a utility front ends may use to satisfy the Naga IR's
+ /// requirement, meant to ensure that input languages' policies have been
+ /// applied appropriately, that all I/O `Binding`s from the vertex shader to the
+ /// fragment shader must have non-`None` `interpolation` values.
+ ///
+ /// All the shader languages Naga supports have similar rules:
+ /// perspective-correct, center-sampled interpolation is the default for any
+ /// binding that can vary, and everything else either defaults to flat, or
+ /// requires an explicit flat qualifier/attribute/what-have-you.
+ ///
+ /// If `binding` is not a [`Location`] binding, or if its [`interpolation`] is
+ /// already set, then make no changes. Otherwise, set `binding`'s interpolation
+ /// and sampling to reasonable defaults depending on `ty`, the type of the value
+ /// being interpolated:
+ ///
+ /// - If `ty` is a floating-point scalar, vector, or matrix type, then
+ /// default to [`Perspective`] interpolation and [`Center`] sampling.
+ ///
+ /// - If `ty` is an integral scalar or vector, then default to [`Flat`]
+ /// interpolation, which has no associated sampling.
+ ///
+ /// - For any other types, make no change. Such types are not permitted as
+ /// user-defined IO values, and will probably be flagged by the verifier
+ ///
+ /// When structs appear in input or output types, each member ought to have its
+ /// own [`Binding`], so structs are simply covered by the third case.
+ ///
+ /// [`Binding`]: crate::Binding
+ /// [`Location`]: crate::Binding::Location
+ /// [`interpolation`]: crate::Binding::Location::interpolation
+ /// [`Perspective`]: crate::Interpolation::Perspective
+ /// [`Flat`]: crate::Interpolation::Flat
+ /// [`Center`]: crate::Sampling::Center
+ pub fn apply_default_interpolation(&mut self, ty: &crate::TypeInner) {
+ if let crate::Binding::Location {
+ location: _,
+ interpolation: ref mut interpolation @ None,
+ ref mut sampling,
+ } = *self
+ {
+ match ty.scalar_kind() {
+ Some(crate::ScalarKind::Float) => {
+ *interpolation = Some(crate::Interpolation::Perspective);
+ *sampling = Some(crate::Sampling::Center);
+ }
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ *interpolation = Some(crate::Interpolation::Flat);
+ *sampling = None;
+ }
+ Some(_) | None => {}
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/mod.rs b/third_party/rust/naga/src/front/mod.rs
new file mode 100644
index 0000000000..d6f38671ea
--- /dev/null
+++ b/third_party/rust/naga/src/front/mod.rs
@@ -0,0 +1,374 @@
+/*!
+Frontend parsers that consume binary and text shaders and load them into [`Module`](super::Module)s.
+*/
+
+mod interpolator;
+mod type_gen;
+
+#[cfg(feature = "glsl-in")]
+pub mod glsl;
+#[cfg(feature = "spv-in")]
+pub mod spv;
+#[cfg(feature = "wgsl-in")]
+pub mod wgsl;
+
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ proc::{ResolveContext, ResolveError, TypeResolution},
+ FastHashMap,
+};
+use std::ops;
+
+/// Helper class to emit expressions
+#[allow(dead_code)]
+#[derive(Default, Debug)]
+struct Emitter {
+ start_len: Option<usize>,
+}
+
+#[allow(dead_code)]
+impl Emitter {
+ fn start(&mut self, arena: &Arena<crate::Expression>) {
+ if self.start_len.is_some() {
+ unreachable!("Emitting has already started!");
+ }
+ self.start_len = Some(arena.len());
+ }
+ #[must_use]
+ fn finish(
+ &mut self,
+ arena: &Arena<crate::Expression>,
+ ) -> Option<(crate::Statement, crate::span::Span)> {
+ let start_len = self.start_len.take().unwrap();
+ if start_len != arena.len() {
+ #[allow(unused_mut)]
+ let mut span = crate::span::Span::default();
+ let range = arena.range_from(start_len);
+ #[cfg(feature = "span")]
+ for handle in range.clone() {
+ span.subsume(arena.get_span(handle))
+ }
+ Some((crate::Statement::Emit(range), span))
+ } else {
+ None
+ }
+ }
+}
+
+#[allow(dead_code)]
+impl super::ConstantInner {
+ const fn boolean(value: bool) -> Self {
+ Self::Scalar {
+ width: super::BOOL_WIDTH,
+ value: super::ScalarValue::Bool(value),
+ }
+ }
+}
+
+/// A table of types for an `Arena<Expression>`.
+///
+/// A front end can use a `Typifier` to get types for an arena's expressions
+/// while it is still contributing expressions to it. At any point, you can call
+/// [`typifier.grow(expr, arena, ctx)`], where `expr` is a `Handle<Expression>`
+/// referring to something in `arena`, and the `Typifier` will resolve the types
+/// of all the expressions up to and including `expr`. Then you can write
+/// `typifier[handle]` to get the type of any handle at or before `expr`.
+///
+/// Note that `Typifier` does *not* build an `Arena<Type>` as a part of its
+/// usual operation. Ideally, a module's type arena should only contain types
+/// actually needed by `Handle<Type>`s elsewhere in the module — functions,
+/// variables, [`Compose`] expressions, other types, and so on — so we don't
+/// want every little thing that occurs as the type of some intermediate
+/// expression to show up there.
+///
+/// Instead, `Typifier` accumulates a [`TypeResolution`] for each expression,
+/// which refers to the `Arena<Type>` in the [`ResolveContext`] passed to `grow`
+/// as needed. [`TypeResolution`] is a lightweight representation for
+/// intermediate types like this; see its documentation for details.
+///
+/// If you do need to register a `Typifier`'s conclusion in an `Arena<Type>`
+/// (say, for a [`LocalVariable`] whose type you've inferred), you can use
+/// [`register_type`] to do so.
+///
+/// [`typifier.grow(expr, arena)`]: Typifier::grow
+/// [`register_type`]: Typifier::register_type
+/// [`Compose`]: crate::Expression::Compose
+/// [`LocalVariable`]: crate::LocalVariable
+#[derive(Debug, Default)]
+pub struct Typifier {
+ resolutions: Vec<TypeResolution>,
+}
+
+impl Typifier {
+ pub const fn new() -> Self {
+ Typifier {
+ resolutions: Vec::new(),
+ }
+ }
+
+ pub fn reset(&mut self) {
+ self.resolutions.clear()
+ }
+
+ pub fn get<'a>(
+ &'a self,
+ expr_handle: Handle<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ ) -> &'a crate::TypeInner {
+ self.resolutions[expr_handle.index()].inner_with(types)
+ }
+
+ /// Add an expression's type to an `Arena<Type>`.
+ ///
+ /// Add the type of `expr_handle` to `types`, and return a `Handle<Type>`
+ /// referring to it.
+ ///
+ /// # Note
+ ///
+ /// If you just need a [`TypeInner`] for `expr_handle`'s type, consider
+ /// using `typifier[expression].inner_with(types)` instead. Calling
+ /// [`TypeResolution::inner_with`] often lets us avoid adding anything to
+ /// the arena, which can significantly reduce the number of types that end
+ /// up in the final module.
+ ///
+ /// [`TypeInner`]: crate::TypeInner
+ pub fn register_type(
+ &self,
+ expr_handle: Handle<crate::Expression>,
+ types: &mut UniqueArena<crate::Type>,
+ ) -> Handle<crate::Type> {
+ match self[expr_handle].clone() {
+ TypeResolution::Handle(handle) => handle,
+ TypeResolution::Value(inner) => {
+ types.insert(crate::Type { name: None, inner }, crate::Span::UNDEFINED)
+ }
+ }
+ }
+
+ /// Grow this typifier until it contains a type for `expr_handle`.
+ pub fn grow(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ if self.resolutions.len() <= expr_handle.index() {
+ for (eh, expr) in expressions.iter().skip(self.resolutions.len()) {
+ //Note: the closure can't `Err` by construction
+ let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?;
+ log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution);
+ self.resolutions.push(resolution);
+ }
+ }
+ Ok(())
+ }
+
+ /// Recompute the type resolution for `expr_handle`.
+ ///
+ /// If the type of `expr_handle` hasn't yet been calculated, call
+ /// [`grow`](Self::grow) to ensure it is covered.
+ ///
+ /// In either case, when this returns, `self[expr_handle]` should be an
+ /// updated type resolution for `expr_handle`.
+ pub fn invalidate(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ if self.resolutions.len() <= expr_handle.index() {
+ self.grow(expr_handle, expressions, ctx)
+ } else {
+ let expr = &expressions[expr_handle];
+ //Note: the closure can't `Err` by construction
+ let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?;
+ self.resolutions[expr_handle.index()] = resolution;
+ Ok(())
+ }
+ }
+}
+
+impl ops::Index<Handle<crate::Expression>> for Typifier {
+ type Output = TypeResolution;
+ fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
+ &self.resolutions[handle.index()]
+ }
+}
+
+/// Type representing a lexical scope, associating a name to a single variable
+///
+/// The scope is generic over the variable representation and name representaion
+/// in order to allow larger flexibility on the frontends on how they might
+/// represent them.
+type Scope<Name, Var> = FastHashMap<Name, Var>;
+
+/// Structure responsible for managing variable lookups and keeping track of
+/// lexical scopes
+///
+/// The symbol table is generic over the variable representation and its name
+/// to allow larger flexibility on the frontends on how they might represent them.
+///
+/// ```
+/// use naga::front::SymbolTable;
+///
+/// // Create a new symbol table with `u32`s representing the variable
+/// let mut symbol_table: SymbolTable<&str, u32> = SymbolTable::default();
+///
+/// // Add two variables named `var1` and `var2` with 0 and 2 respectively
+/// symbol_table.add("var1", 0);
+/// symbol_table.add("var2", 2);
+///
+/// // Check that `var1` exists and is `0`
+/// assert_eq!(symbol_table.lookup("var1"), Some(&0));
+///
+/// // Push a new scope and add a variable to it named `var1` shadowing the
+/// // variable of our previous scope
+/// symbol_table.push_scope();
+/// symbol_table.add("var1", 1);
+///
+/// // Check that `var1` now points to the new value of `1` and `var2` still
+/// // exists with its value of `2`
+/// assert_eq!(symbol_table.lookup("var1"), Some(&1));
+/// assert_eq!(symbol_table.lookup("var2"), Some(&2));
+///
+/// // Pop the scope
+/// symbol_table.pop_scope();
+///
+/// // Check that `var1` now refers to our initial variable with value `0`
+/// assert_eq!(symbol_table.lookup("var1"), Some(&0));
+/// ```
+///
+/// Scopes are ordered as a LIFO stack so a variable defined in a later scope
+/// with the same name as another variable defined in a earlier scope will take
+/// precedence in the lookup. Scopes can be added with [`push_scope`] and
+/// removed with [`pop_scope`].
+///
+/// A root scope is added when the symbol table is created and must always be
+/// present. Trying to pop it will result in a panic.
+///
+/// Variables can be added with [`add`] and looked up with [`lookup`]. Adding a
+/// variable will do so in the currently active scope and as mentioned
+/// previously a lookup will search from the current scope to the root scope.
+///
+/// [`push_scope`]: Self::push_scope
+/// [`pop_scope`]: Self::push_scope
+/// [`add`]: Self::add
+/// [`lookup`]: Self::lookup
+pub struct SymbolTable<Name, Var> {
+ /// Stack of lexical scopes. Not all scopes are active; see [`cursor`].
+ ///
+ /// [`cursor`]: Self::cursor
+ scopes: Vec<Scope<Name, Var>>,
+ /// Limit of the [`scopes`] stack (exclusive). By using a separate value for
+ /// the stack length instead of `Vec`'s own internal length, the scopes can
+ /// be reused to cache memory allocations.
+ ///
+ /// [`scopes`]: Self::scopes
+ cursor: usize,
+}
+
+impl<Name, Var> SymbolTable<Name, Var> {
+ /// Adds a new lexical scope.
+ ///
+ /// All variables declared after this point will be added to this scope
+ /// until another scope is pushed or [`pop_scope`] is called, causing this
+ /// scope to be removed along with all variables added to it.
+ ///
+ /// [`pop_scope`]: Self::pop_scope
+ pub fn push_scope(&mut self) {
+ // If the cursor is equal to the scope's stack length then we need to
+ // push another empty scope. Otherwise we can reuse the already existing
+ // scope.
+ if self.scopes.len() == self.cursor {
+ self.scopes.push(FastHashMap::default())
+ } else {
+ self.scopes[self.cursor].clear();
+ }
+
+ self.cursor += 1;
+ }
+
+ /// Removes the current lexical scope and all its variables
+ ///
+ /// # PANICS
+ /// - If the current lexical scope is the root scope
+ pub fn pop_scope(&mut self) {
+ // Despite the method title, the variables are only deleted when the
+ // scope is reused. This is because while a clear is inevitable if the
+ // scope needs to be reused, there are cases where the scope might be
+ // popped and not reused, i.e. if another scope with the same nesting
+ // level is never pushed again.
+ assert!(self.cursor != 1, "Tried to pop the root scope");
+
+ self.cursor -= 1;
+ }
+}
+
+impl<Name, Var> SymbolTable<Name, Var>
+where
+ Name: std::hash::Hash + Eq,
+{
+ /// Perform a lookup for a variable named `name`.
+ ///
+ /// As stated in the struct level documentation the lookup will proceed from
+ /// the current scope to the root scope, returning `Some` when a variable is
+ /// found or `None` if there doesn't exist a variable with `name` in any
+ /// scope.
+ pub fn lookup<Q: ?Sized>(&self, name: &Q) -> Option<&Var>
+ where
+ Name: std::borrow::Borrow<Q>,
+ Q: std::hash::Hash + Eq,
+ {
+ // Iterate backwards trough the scopes and try to find the variable
+ for scope in self.scopes[..self.cursor].iter().rev() {
+ if let Some(var) = scope.get(name) {
+ return Some(var);
+ }
+ }
+
+ None
+ }
+
+ /// Adds a new variable to the current scope.
+ ///
+ /// Returns the previous variable with the same name in this scope if it
+ /// exists, so that the frontend might handle it in case variable shadowing
+ /// is disallowed.
+ pub fn add(&mut self, name: Name, var: Var) -> Option<Var> {
+ self.scopes[self.cursor - 1].insert(name, var)
+ }
+
+ /// Adds a new variable to the root scope.
+ ///
+ /// This is used in GLSL for builtins which aren't known in advance and only
+ /// when used for the first time, so there must be a way to add those
+ /// declarations to the root unconditionally from the current scope.
+ ///
+ /// Returns the previous variable with the same name in the root scope if it
+ /// exists, so that the frontend might handle it in case variable shadowing
+ /// is disallowed.
+ pub fn add_root(&mut self, name: Name, var: Var) -> Option<Var> {
+ self.scopes[0].insert(name, var)
+ }
+}
+
+impl<Name, Var> Default for SymbolTable<Name, Var> {
+ /// Constructs a new symbol table with a root scope
+ fn default() -> Self {
+ Self {
+ scopes: vec![FastHashMap::default()],
+ cursor: 1,
+ }
+ }
+}
+
+use std::fmt;
+
+impl<Name: fmt::Debug, Var: fmt::Debug> fmt::Debug for SymbolTable<Name, Var> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("SymbolTable ")?;
+ f.debug_list()
+ .entries(self.scopes[..self.cursor].iter())
+ .finish()
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs
new file mode 100644
index 0000000000..2967555d2b
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/convert.rs
@@ -0,0 +1,181 @@
+use super::error::Error;
+use num_traits::cast::FromPrimitive;
+use std::convert::TryInto;
+
+pub(super) const fn map_binary_operator(word: spirv::Op) -> Result<crate::BinaryOperator, Error> {
+ use crate::BinaryOperator;
+ use spirv::Op;
+
+ match word {
+ // Arithmetic Instructions +, -, *, /, %
+ Op::IAdd | Op::FAdd => Ok(BinaryOperator::Add),
+ Op::ISub | Op::FSub => Ok(BinaryOperator::Subtract),
+ Op::IMul | Op::FMul => Ok(BinaryOperator::Multiply),
+ Op::UDiv | Op::SDiv | Op::FDiv => Ok(BinaryOperator::Divide),
+ Op::SRem => Ok(BinaryOperator::Modulo),
+ // Relational and Logical Instructions
+ Op::IEqual | Op::FOrdEqual | Op::FUnordEqual | Op::LogicalEqual => {
+ Ok(BinaryOperator::Equal)
+ }
+ Op::INotEqual | Op::FOrdNotEqual | Op::FUnordNotEqual | Op::LogicalNotEqual => {
+ Ok(BinaryOperator::NotEqual)
+ }
+ Op::ULessThan | Op::SLessThan | Op::FOrdLessThan | Op::FUnordLessThan => {
+ Ok(BinaryOperator::Less)
+ }
+ Op::ULessThanEqual
+ | Op::SLessThanEqual
+ | Op::FOrdLessThanEqual
+ | Op::FUnordLessThanEqual => Ok(BinaryOperator::LessEqual),
+ Op::UGreaterThan | Op::SGreaterThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan => {
+ Ok(BinaryOperator::Greater)
+ }
+ Op::UGreaterThanEqual
+ | Op::SGreaterThanEqual
+ | Op::FOrdGreaterThanEqual
+ | Op::FUnordGreaterThanEqual => Ok(BinaryOperator::GreaterEqual),
+ Op::BitwiseOr => Ok(BinaryOperator::InclusiveOr),
+ Op::BitwiseXor => Ok(BinaryOperator::ExclusiveOr),
+ Op::BitwiseAnd => Ok(BinaryOperator::And),
+ _ => Err(Error::UnknownBinaryOperator(word)),
+ }
+}
+
+pub(super) const fn map_relational_fun(
+ word: spirv::Op,
+) -> Result<crate::RelationalFunction, Error> {
+ use crate::RelationalFunction as Rf;
+ use spirv::Op;
+
+ match word {
+ Op::All => Ok(Rf::All),
+ Op::Any => Ok(Rf::Any),
+ Op::IsNan => Ok(Rf::IsNan),
+ Op::IsInf => Ok(Rf::IsInf),
+ Op::IsFinite => Ok(Rf::IsFinite),
+ Op::IsNormal => Ok(Rf::IsNormal),
+ _ => Err(Error::UnknownRelationalFunction(word)),
+ }
+}
+
+pub(super) const fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> {
+ match word {
+ 2 => Ok(crate::VectorSize::Bi),
+ 3 => Ok(crate::VectorSize::Tri),
+ 4 => Ok(crate::VectorSize::Quad),
+ _ => Err(Error::InvalidVectorSize(word)),
+ }
+}
+
+pub(super) fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> {
+ use spirv::Dim as D;
+ match D::from_u32(word) {
+ Some(D::Dim1D) => Ok(crate::ImageDimension::D1),
+ Some(D::Dim2D) => Ok(crate::ImageDimension::D2),
+ Some(D::Dim3D) => Ok(crate::ImageDimension::D3),
+ Some(D::DimCube) => Ok(crate::ImageDimension::Cube),
+ _ => Err(Error::UnsupportedImageDim(word)),
+ }
+}
+
+pub(super) fn map_image_format(word: spirv::Word) -> Result<crate::StorageFormat, Error> {
+ match spirv::ImageFormat::from_u32(word) {
+ Some(spirv::ImageFormat::R8) => Ok(crate::StorageFormat::R8Unorm),
+ Some(spirv::ImageFormat::R8Snorm) => Ok(crate::StorageFormat::R8Snorm),
+ Some(spirv::ImageFormat::R8ui) => Ok(crate::StorageFormat::R8Uint),
+ Some(spirv::ImageFormat::R8i) => Ok(crate::StorageFormat::R8Sint),
+ Some(spirv::ImageFormat::R16) => Ok(crate::StorageFormat::R16Unorm),
+ Some(spirv::ImageFormat::R16Snorm) => Ok(crate::StorageFormat::R16Snorm),
+ Some(spirv::ImageFormat::R16ui) => Ok(crate::StorageFormat::R16Uint),
+ Some(spirv::ImageFormat::R16i) => Ok(crate::StorageFormat::R16Sint),
+ Some(spirv::ImageFormat::R16f) => Ok(crate::StorageFormat::R16Float),
+ Some(spirv::ImageFormat::Rg8) => Ok(crate::StorageFormat::Rg8Unorm),
+ Some(spirv::ImageFormat::Rg8Snorm) => Ok(crate::StorageFormat::Rg8Snorm),
+ Some(spirv::ImageFormat::Rg8ui) => Ok(crate::StorageFormat::Rg8Uint),
+ Some(spirv::ImageFormat::Rg8i) => Ok(crate::StorageFormat::Rg8Sint),
+ Some(spirv::ImageFormat::R32ui) => Ok(crate::StorageFormat::R32Uint),
+ Some(spirv::ImageFormat::R32i) => Ok(crate::StorageFormat::R32Sint),
+ Some(spirv::ImageFormat::R32f) => Ok(crate::StorageFormat::R32Float),
+ Some(spirv::ImageFormat::Rg16) => Ok(crate::StorageFormat::Rg16Unorm),
+ Some(spirv::ImageFormat::Rg16Snorm) => Ok(crate::StorageFormat::Rg16Snorm),
+ Some(spirv::ImageFormat::Rg16ui) => Ok(crate::StorageFormat::Rg16Uint),
+ Some(spirv::ImageFormat::Rg16i) => Ok(crate::StorageFormat::Rg16Sint),
+ Some(spirv::ImageFormat::Rg16f) => Ok(crate::StorageFormat::Rg16Float),
+ Some(spirv::ImageFormat::Rgba8) => Ok(crate::StorageFormat::Rgba8Unorm),
+ Some(spirv::ImageFormat::Rgba8Snorm) => Ok(crate::StorageFormat::Rgba8Snorm),
+ Some(spirv::ImageFormat::Rgba8ui) => Ok(crate::StorageFormat::Rgba8Uint),
+ Some(spirv::ImageFormat::Rgba8i) => Ok(crate::StorageFormat::Rgba8Sint),
+ Some(spirv::ImageFormat::Rgb10a2ui) => Ok(crate::StorageFormat::Rgb10a2Unorm),
+ Some(spirv::ImageFormat::R11fG11fB10f) => Ok(crate::StorageFormat::Rg11b10Float),
+ Some(spirv::ImageFormat::Rg32ui) => Ok(crate::StorageFormat::Rg32Uint),
+ Some(spirv::ImageFormat::Rg32i) => Ok(crate::StorageFormat::Rg32Sint),
+ Some(spirv::ImageFormat::Rg32f) => Ok(crate::StorageFormat::Rg32Float),
+ Some(spirv::ImageFormat::Rgba16) => Ok(crate::StorageFormat::Rgba16Unorm),
+ Some(spirv::ImageFormat::Rgba16Snorm) => Ok(crate::StorageFormat::Rgba16Snorm),
+ Some(spirv::ImageFormat::Rgba16ui) => Ok(crate::StorageFormat::Rgba16Uint),
+ Some(spirv::ImageFormat::Rgba16i) => Ok(crate::StorageFormat::Rgba16Sint),
+ Some(spirv::ImageFormat::Rgba16f) => Ok(crate::StorageFormat::Rgba16Float),
+ Some(spirv::ImageFormat::Rgba32ui) => Ok(crate::StorageFormat::Rgba32Uint),
+ Some(spirv::ImageFormat::Rgba32i) => Ok(crate::StorageFormat::Rgba32Sint),
+ Some(spirv::ImageFormat::Rgba32f) => Ok(crate::StorageFormat::Rgba32Float),
+ _ => Err(Error::UnsupportedImageFormat(word)),
+ }
+}
+
+pub(super) fn map_width(word: spirv::Word) -> Result<crate::Bytes, Error> {
+ (word >> 3) // bits to bytes
+ .try_into()
+ .map_err(|_| Error::InvalidTypeWidth(word))
+}
+
+pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::BuiltIn, Error> {
+ use spirv::BuiltIn as Bi;
+ Ok(match spirv::BuiltIn::from_u32(word) {
+ Some(Bi::Position | Bi::FragCoord) => crate::BuiltIn::Position { invariant },
+ Some(Bi::ViewIndex) => crate::BuiltIn::ViewIndex,
+ // vertex
+ Some(Bi::BaseInstance) => crate::BuiltIn::BaseInstance,
+ Some(Bi::BaseVertex) => crate::BuiltIn::BaseVertex,
+ Some(Bi::ClipDistance) => crate::BuiltIn::ClipDistance,
+ Some(Bi::CullDistance) => crate::BuiltIn::CullDistance,
+ Some(Bi::InstanceIndex) => crate::BuiltIn::InstanceIndex,
+ Some(Bi::PointSize) => crate::BuiltIn::PointSize,
+ Some(Bi::VertexIndex) => crate::BuiltIn::VertexIndex,
+ // fragment
+ Some(Bi::FragDepth) => crate::BuiltIn::FragDepth,
+ Some(Bi::PointCoord) => crate::BuiltIn::PointCoord,
+ Some(Bi::FrontFacing) => crate::BuiltIn::FrontFacing,
+ Some(Bi::PrimitiveId) => crate::BuiltIn::PrimitiveIndex,
+ Some(Bi::SampleId) => crate::BuiltIn::SampleIndex,
+ Some(Bi::SampleMask) => crate::BuiltIn::SampleMask,
+ // compute
+ Some(Bi::GlobalInvocationId) => crate::BuiltIn::GlobalInvocationId,
+ Some(Bi::LocalInvocationId) => crate::BuiltIn::LocalInvocationId,
+ Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex,
+ Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
+ Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
+ Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
+ _ => return Err(Error::UnsupportedBuiltIn(word)),
+ })
+}
+
+pub(super) fn map_storage_class(word: spirv::Word) -> Result<super::ExtendedClass, Error> {
+ use super::ExtendedClass as Ec;
+ use spirv::StorageClass as Sc;
+ Ok(match Sc::from_u32(word) {
+ Some(Sc::Function) => Ec::Global(crate::AddressSpace::Function),
+ Some(Sc::Input) => Ec::Input,
+ Some(Sc::Output) => Ec::Output,
+ Some(Sc::Private) => Ec::Global(crate::AddressSpace::Private),
+ Some(Sc::UniformConstant) => Ec::Global(crate::AddressSpace::Handle),
+ Some(Sc::StorageBuffer) => Ec::Global(crate::AddressSpace::Storage {
+ //Note: this is restricted by decorations later
+ access: crate::StorageAccess::all(),
+ }),
+ // we expect the `Storage` case to be filtered out before calling this function.
+ Some(Sc::Uniform) => Ec::Global(crate::AddressSpace::Uniform),
+ Some(Sc::Workgroup) => Ec::Global(crate::AddressSpace::WorkGroup),
+ Some(Sc::PushConstant) => Ec::Global(crate::AddressSpace::PushConstant),
+ _ => return Err(Error::UnsupportedStorageClass(word)),
+ })
+}
diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs
new file mode 100644
index 0000000000..6c9cf384f1
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/error.rs
@@ -0,0 +1,127 @@
+use super::ModuleState;
+use crate::arena::Handle;
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("invalid header")]
+ InvalidHeader,
+ #[error("invalid word count")]
+ InvalidWordCount,
+ #[error("unknown instruction {0}")]
+ UnknownInstruction(u16),
+ #[error("unknown capability %{0}")]
+ UnknownCapability(spirv::Word),
+ #[error("unsupported instruction {1:?} at {0:?}")]
+ UnsupportedInstruction(ModuleState, spirv::Op),
+ #[error("unsupported capability {0:?}")]
+ UnsupportedCapability(spirv::Capability),
+ #[error("unsupported extension {0}")]
+ UnsupportedExtension(String),
+ #[error("unsupported extension set {0}")]
+ UnsupportedExtSet(String),
+ #[error("unsupported extension instantiation set %{0}")]
+ UnsupportedExtInstSet(spirv::Word),
+ #[error("unsupported extension instantiation %{0}")]
+ UnsupportedExtInst(spirv::Word),
+ #[error("unsupported type {0:?}")]
+ UnsupportedType(Handle<crate::Type>),
+ #[error("unsupported execution model %{0}")]
+ UnsupportedExecutionModel(spirv::Word),
+ #[error("unsupported execution mode %{0}")]
+ UnsupportedExecutionMode(spirv::Word),
+ #[error("unsupported storage class %{0}")]
+ UnsupportedStorageClass(spirv::Word),
+ #[error("unsupported image dimension %{0}")]
+ UnsupportedImageDim(spirv::Word),
+ #[error("unsupported image format %{0}")]
+ UnsupportedImageFormat(spirv::Word),
+ #[error("unsupported builtin %{0}")]
+ UnsupportedBuiltIn(spirv::Word),
+ #[error("unsupported control flow %{0}")]
+ UnsupportedControlFlow(spirv::Word),
+ #[error("unsupported binary operator %{0}")]
+ UnsupportedBinaryOperator(spirv::Word),
+ #[error("Naga supports OpTypeRuntimeArray in the StorageBuffer storage class only")]
+ UnsupportedRuntimeArrayStorageClass,
+ #[error("unsupported matrix stride {stride} for a {columns}x{rows} matrix with scalar width={width}")]
+ UnsupportedMatrixStride {
+ stride: u32,
+ columns: u8,
+ rows: u8,
+ width: u8,
+ },
+ #[error("unknown binary operator {0:?}")]
+ UnknownBinaryOperator(spirv::Op),
+ #[error("unknown relational function {0:?}")]
+ UnknownRelationalFunction(spirv::Op),
+ #[error("invalid parameter {0:?}")]
+ InvalidParameter(spirv::Op),
+ #[error("invalid operand count {1} for {0:?}")]
+ InvalidOperandCount(spirv::Op, u16),
+ #[error("invalid operand")]
+ InvalidOperand,
+ #[error("invalid id %{0}")]
+ InvalidId(spirv::Word),
+ #[error("invalid decoration %{0}")]
+ InvalidDecoration(spirv::Word),
+ #[error("invalid type width %{0}")]
+ InvalidTypeWidth(spirv::Word),
+ #[error("invalid sign %{0}")]
+ InvalidSign(spirv::Word),
+ #[error("invalid inner type %{0}")]
+ InvalidInnerType(spirv::Word),
+ #[error("invalid vector size %{0}")]
+ InvalidVectorSize(spirv::Word),
+ #[error("invalid access type %{0}")]
+ InvalidAccessType(spirv::Word),
+ #[error("invalid access {0:?}")]
+ InvalidAccess(crate::Expression),
+ #[error("invalid access index %{0}")]
+ InvalidAccessIndex(spirv::Word),
+ #[error("invalid binding %{0}")]
+ InvalidBinding(spirv::Word),
+ #[error("invalid global var {0:?}")]
+ InvalidGlobalVar(crate::Expression),
+ #[error("invalid image/sampler expression {0:?}")]
+ InvalidImageExpression(crate::Expression),
+ #[error("invalid image base type {0:?}")]
+ InvalidImageBaseType(Handle<crate::Type>),
+ #[error("invalid image {0:?}")]
+ InvalidImage(Handle<crate::Type>),
+ #[error("invalid as type {0:?}")]
+ InvalidAsType(Handle<crate::Type>),
+ #[error("invalid vector type {0:?}")]
+ InvalidVectorType(Handle<crate::Type>),
+ #[error("inconsistent comparison sampling {0:?}")]
+ InconsistentComparisonSampling(Handle<crate::GlobalVariable>),
+ #[error("wrong function result type %{0}")]
+ WrongFunctionResultType(spirv::Word),
+ #[error("wrong function argument type %{0}")]
+ WrongFunctionArgumentType(spirv::Word),
+ #[error("missing decoration {0:?}")]
+ MissingDecoration(spirv::Decoration),
+ #[error("bad string")]
+ BadString,
+ #[error("incomplete data")]
+ IncompleteData,
+ #[error("invalid terminator")]
+ InvalidTerminator,
+ #[error("invalid edge classification")]
+ InvalidEdgeClassification,
+ #[error("cycle detected in the CFG during traversal at {0}")]
+ ControlFlowGraphCycle(crate::front::spv::BlockId),
+ #[error("recursive function call %{0}")]
+ FunctionCallCycle(spirv::Word),
+ #[error("invalid array size {0:?}")]
+ InvalidArraySize(Handle<crate::Constant>),
+ #[error("invalid barrier scope %{0}")]
+ InvalidBarrierScope(spirv::Word),
+ #[error("invalid barrier memory semantics %{0}")]
+ InvalidBarrierMemorySemantics(spirv::Word),
+ #[error(
+ "arrays of images / samplers are supported only through bindings for \
+ now (i.e. you can't create an array of images or samplers that doesn't \
+ come from a binding)"
+ )]
+ NonBindingArrayOfImageOrSamplers,
+}
diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs
new file mode 100644
index 0000000000..5dc781504e
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/function.rs
@@ -0,0 +1,661 @@
+use crate::{
+ arena::{Arena, Handle},
+ front::spv::{BlockContext, BodyIndex},
+};
+
+use super::{Error, Instruction, LookupExpression, LookupHelper as _};
+use crate::front::Emitter;
+
+pub type BlockId = u32;
+
+#[derive(Copy, Clone, Debug)]
+pub struct MergeInstruction {
+ pub merge_block_id: BlockId,
+ pub continue_block_id: Option<BlockId>,
+}
+
+impl<I: Iterator<Item = u32>> super::Frontend<I> {
+ // Registers a function call. It will generate a dummy handle to call, which
+ // gets resolved after all the functions are processed.
+ pub(super) fn add_call(
+ &mut self,
+ from: spirv::Word,
+ to: spirv::Word,
+ ) -> Handle<crate::Function> {
+ let dummy_handle = self
+ .dummy_functions
+ .append(crate::Function::default(), Default::default());
+ self.deferred_function_calls.push(to);
+ self.function_call_graph.add_edge(from, to, ());
+ dummy_handle
+ }
+
+ pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.lookup_expression.clear();
+ self.lookup_load_override.clear();
+ self.lookup_sampled_image.clear();
+
+ let result_type_id = self.next()?;
+ let fun_id = self.next()?;
+ let _fun_control = self.next()?;
+ let fun_type_id = self.next()?;
+
+ let mut fun = {
+ let ft = self.lookup_function_type.lookup(fun_type_id)?;
+ if ft.return_type_id != result_type_id {
+ return Err(Error::WrongFunctionResultType(result_type_id));
+ }
+ crate::Function {
+ name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
+ arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
+ result: if self.lookup_void_type == Some(result_type_id) {
+ None
+ } else {
+ let lookup_result_ty = self.lookup_type.lookup(result_type_id)?;
+ Some(crate::FunctionResult {
+ ty: lookup_result_ty.handle,
+ binding: None,
+ })
+ },
+ local_variables: Arena::new(),
+ expressions: self
+ .make_expression_storage(&module.global_variables, &module.constants),
+ named_expressions: crate::NamedExpressions::default(),
+ body: crate::Block::new(),
+ }
+ };
+
+ // read parameters
+ for i in 0..fun.arguments.capacity() {
+ let start = self.data_offset;
+ match self.next_inst()? {
+ Instruction {
+ op: spirv::Op::FunctionParameter,
+ wc: 3,
+ } => {
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let handle = fun.expressions.append(
+ crate::Expression::FunctionArgument(i as u32),
+ self.span_from(start),
+ );
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id,
+ // Setting this to an invalid id will cause get_expr_handle
+ // to default to the main body making sure no load/stores
+ // are added.
+ block_id: 0,
+ },
+ );
+ //Note: we redo the lookup in order to work around `self` borrowing
+
+ if type_id
+ != self
+ .lookup_function_type
+ .lookup(fun_type_id)?
+ .parameter_type_ids[i]
+ {
+ return Err(Error::WrongFunctionArgumentType(type_id));
+ }
+ let ty = self.lookup_type.lookup(type_id)?.handle;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ fun.arguments.push(crate::FunctionArgument {
+ name: decor.name,
+ ty,
+ binding: None,
+ });
+ }
+ Instruction { op, .. } => return Err(Error::InvalidParameter(op)),
+ }
+ }
+
+ // Read body
+ self.function_call_graph.add_node(fun_id);
+ let mut parameters_sampling =
+ vec![super::image::SamplingFlags::empty(); fun.arguments.len()];
+
+ let mut block_ctx = BlockContext {
+ phis: Default::default(),
+ blocks: Default::default(),
+ body_for_label: Default::default(),
+ mergers: Default::default(),
+ bodies: Default::default(),
+ function_id: fun_id,
+ expressions: &mut fun.expressions,
+ local_arena: &mut fun.local_variables,
+ const_arena: &mut module.constants,
+ type_arena: &module.types,
+ global_arena: &module.global_variables,
+ arguments: &fun.arguments,
+ parameter_sampling: &mut parameters_sampling,
+ };
+ // Insert the main body whose parent is also himself
+ block_ctx.bodies.push(super::Body::with_parent(0));
+
+ // Scan the blocks and add them as nodes
+ loop {
+ let fun_inst = self.next_inst()?;
+ log::debug!("{:?}", fun_inst.op);
+ match fun_inst.op {
+ spirv::Op::Line => {
+ fun_inst.expect(4)?;
+ let _file_id = self.next()?;
+ let _row_id = self.next()?;
+ let _col_id = self.next()?;
+ }
+ spirv::Op::Label => {
+ // Read the label ID
+ fun_inst.expect(2)?;
+ let block_id = self.next()?;
+
+ self.next_block(block_id, &mut block_ctx)?;
+ }
+ spirv::Op::FunctionEnd => {
+ fun_inst.expect(1)?;
+ break;
+ }
+ _ => {
+ return Err(Error::UnsupportedInstruction(self.state, fun_inst.op));
+ }
+ }
+ }
+
+ if let Some(ref prefix) = self.options.block_ctx_dump_prefix {
+ let dump_suffix = match self.lookup_entry_point.get(&fun_id) {
+ Some(ep) => format!("block_ctx.{:?}-{}.txt", ep.stage, ep.name),
+ None => format!("block_ctx.Fun-{}.txt", module.functions.len()),
+ };
+ let dest = prefix.join(dump_suffix);
+ let dump = format!("{block_ctx:#?}");
+ if let Err(e) = std::fs::write(&dest, dump) {
+ log::error!("Unable to dump the block context into {:?}: {}", dest, e);
+ }
+ }
+
+ // Emit `Store` statements to properly initialize all the local variables we
+ // created for `phi` expressions.
+ //
+ // Note that get_expr_handle also contributes slightly odd entries to this table,
+ // to get the spill.
+ for phi in block_ctx.phis.iter() {
+ // Get a pointer to the local variable for the phi's value.
+ let phi_pointer = block_ctx.expressions.append(
+ crate::Expression::LocalVariable(phi.local),
+ crate::Span::default(),
+ );
+
+ // At the end of each of `phi`'s predecessor blocks, store the corresponding
+ // source value in the phi's local variable.
+ for &(source, predecessor) in phi.expressions.iter() {
+ let source_lexp = &self.lookup_expression[&source];
+ let predecessor_body_idx = block_ctx.body_for_label[&predecessor];
+ // If the expression is a global/argument it will have a 0 block
+ // id so we must use a default value instead of panicking
+ let source_body_idx = block_ctx
+ .body_for_label
+ .get(&source_lexp.block_id)
+ .copied()
+ .unwrap_or(0);
+
+ // If the Naga `Expression` generated for `source` is in scope, then we
+ // can simply store that in the phi's local variable.
+ //
+ // Otherwise, spill the source value to a local variable in the block that
+ // defines it. (We know this store dominates the predecessor; otherwise,
+ // the phi wouldn't have been able to refer to that source expression in
+ // the first place.) Then, the predecessor block can count on finding the
+ // source's value in that local variable.
+ let value = if super::is_parent(predecessor_body_idx, source_body_idx, &block_ctx) {
+ source_lexp.handle
+ } else {
+ // The source SPIR-V expression is not defined in the phi's
+ // predecessor block, nor is it a globally available expression. So it
+ // must be defined off in some other block that merely dominates the
+ // predecessor. This means that the corresponding Naga `Expression`
+ // may not be in scope in the predecessor block.
+ //
+ // In the block that defines `source`, spill it to a fresh local
+ // variable, to ensure we can still use it at the end of the
+ // predecessor.
+ let ty = self.lookup_type[&source_lexp.type_id].handle;
+ let local = block_ctx.local_arena.append(
+ crate::LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ crate::Span::default(),
+ );
+
+ let pointer = block_ctx.expressions.append(
+ crate::Expression::LocalVariable(local),
+ crate::Span::default(),
+ );
+
+ // Get the spilled value of the source expression.
+ let start = block_ctx.expressions.len();
+ let expr = block_ctx
+ .expressions
+ .append(crate::Expression::Load { pointer }, crate::Span::default());
+ let range = block_ctx.expressions.range_from(start);
+
+ block_ctx
+ .blocks
+ .get_mut(&predecessor)
+ .unwrap()
+ .push(crate::Statement::Emit(range), crate::Span::default());
+
+ // At the end of the block that defines it, spill the source
+ // expression's value.
+ block_ctx
+ .blocks
+ .get_mut(&source_lexp.block_id)
+ .unwrap()
+ .push(
+ crate::Statement::Store {
+ pointer,
+ value: source_lexp.handle,
+ },
+ crate::Span::default(),
+ );
+
+ expr
+ };
+
+ // At the end of the phi predecessor block, store the source
+ // value in the phi's value.
+ block_ctx.blocks.get_mut(&predecessor).unwrap().push(
+ crate::Statement::Store {
+ pointer: phi_pointer,
+ value,
+ },
+ crate::Span::default(),
+ )
+ }
+ }
+
+ fun.body = block_ctx.lower();
+
+ // done
+ let fun_handle = module.functions.append(fun, self.span_from_with_op(start));
+ self.lookup_function.insert(
+ fun_id,
+ super::LookupFunction {
+ handle: fun_handle,
+ parameters_sampling,
+ },
+ );
+
+ if let Some(ep) = self.lookup_entry_point.remove(&fun_id) {
+ // create a wrapping function
+ let mut function = crate::Function {
+ name: Some(format!("{}_wrap", ep.name)),
+ arguments: Vec::new(),
+ result: None,
+ local_variables: Arena::new(),
+ expressions: Arena::new(),
+ named_expressions: crate::NamedExpressions::default(),
+ body: crate::Block::new(),
+ };
+
+ // 1. copy the inputs from arguments to privates
+ for &v_id in ep.variable_ids.iter() {
+ let lvar = self.lookup_variable.lookup(v_id)?;
+ if let super::Variable::Input(ref arg) = lvar.inner {
+ let span = module.global_variables.get_span(lvar.handle);
+ let arg_expr = function.expressions.append(
+ crate::Expression::FunctionArgument(function.arguments.len() as u32),
+ span,
+ );
+ let load_expr = if arg.ty == module.global_variables[lvar.handle].ty {
+ arg_expr
+ } else {
+ // The only case where the type is different is if we need to treat
+ // unsigned integer as signed.
+ let mut emitter = Emitter::default();
+ emitter.start(&function.expressions);
+ let handle = function.expressions.append(
+ crate::Expression::As {
+ expr: arg_expr,
+ kind: crate::ScalarKind::Sint,
+ convert: Some(4),
+ },
+ span,
+ );
+ function.body.extend(emitter.finish(&function.expressions));
+ handle
+ };
+ function.body.push(
+ crate::Statement::Store {
+ pointer: function
+ .expressions
+ .append(crate::Expression::GlobalVariable(lvar.handle), span),
+ value: load_expr,
+ },
+ span,
+ );
+
+ let mut arg = arg.clone();
+ if ep.stage == crate::ShaderStage::Fragment {
+ if let Some(ref mut binding) = arg.binding {
+ binding.apply_default_interpolation(&module.types[arg.ty].inner);
+ }
+ }
+ function.arguments.push(arg);
+ }
+ }
+ // 2. call the wrapped function
+ let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision
+ let dummy_handle = self.add_call(fake_id, fun_id);
+ function.body.push(
+ crate::Statement::Call {
+ function: dummy_handle,
+ arguments: Vec::new(),
+ result: None,
+ },
+ crate::Span::default(),
+ );
+
+ // 3. copy the outputs from privates to the result
+ let mut members = Vec::new();
+ let mut components = Vec::new();
+ for &v_id in ep.variable_ids.iter() {
+ let lvar = self.lookup_variable.lookup(v_id)?;
+ if let super::Variable::Output(ref result) = lvar.inner {
+ let span = module.global_variables.get_span(lvar.handle);
+ let expr_handle = function
+ .expressions
+ .append(crate::Expression::GlobalVariable(lvar.handle), span);
+
+ // Cull problematic builtins of gl_PerVertex.
+ // See the docs for `Frontend::gl_per_vertex_builtin_access`.
+ {
+ let ty = &module.types[result.ty];
+ match ty.inner {
+ crate::TypeInner::Struct {
+ members: ref original_members,
+ span,
+ } if ty.name.as_deref() == Some("gl_PerVertex") => {
+ let mut new_members = original_members.clone();
+ for member in &mut new_members {
+ if let Some(crate::Binding::BuiltIn(built_in)) = member.binding
+ {
+ if !self.gl_per_vertex_builtin_access.contains(&built_in) {
+ member.binding = None
+ }
+ }
+ }
+ if &new_members != original_members {
+ module.types.replace(
+ result.ty,
+ crate::Type {
+ name: ty.name.clone(),
+ inner: crate::TypeInner::Struct {
+ members: new_members,
+ span,
+ },
+ },
+ );
+ }
+ }
+ _ => {}
+ }
+ }
+
+ match module.types[result.ty].inner {
+ crate::TypeInner::Struct {
+ members: ref sub_members,
+ ..
+ } => {
+ for (index, sm) in sub_members.iter().enumerate() {
+ if sm.binding.is_none() {
+ continue;
+ }
+ let mut sm = sm.clone();
+
+ if let Some(ref mut binding) = sm.binding {
+ if ep.stage == crate::ShaderStage::Vertex {
+ binding.apply_default_interpolation(
+ &module.types[sm.ty].inner,
+ );
+ }
+ }
+
+ members.push(sm);
+
+ components.push(function.expressions.append(
+ crate::Expression::AccessIndex {
+ base: expr_handle,
+ index: index as u32,
+ },
+ span,
+ ));
+ }
+ }
+ ref inner => {
+ let mut binding = result.binding.clone();
+ if let Some(ref mut binding) = binding {
+ if ep.stage == crate::ShaderStage::Vertex {
+ binding.apply_default_interpolation(inner);
+ }
+ }
+
+ members.push(crate::StructMember {
+ name: None,
+ ty: result.ty,
+ binding,
+ offset: 0,
+ });
+ // populate just the globals first, then do `Load` in a
+ // separate step, so that we can get a range.
+ components.push(expr_handle);
+ }
+ }
+ }
+ }
+
+ for (member_index, member) in members.iter().enumerate() {
+ match member.binding {
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. }))
+ if self.options.adjust_coordinate_space =>
+ {
+ let mut emitter = Emitter::default();
+ emitter.start(&function.expressions);
+ let global_expr = components[member_index];
+ let span = function.expressions.get_span(global_expr);
+ let access_expr = function.expressions.append(
+ crate::Expression::AccessIndex {
+ base: global_expr,
+ index: 1,
+ },
+ span,
+ );
+ let load_expr = function.expressions.append(
+ crate::Expression::Load {
+ pointer: access_expr,
+ },
+ span,
+ );
+ let neg_expr = function.expressions.append(
+ crate::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr: load_expr,
+ },
+ span,
+ );
+ function.body.extend(emitter.finish(&function.expressions));
+ function.body.push(
+ crate::Statement::Store {
+ pointer: access_expr,
+ value: neg_expr,
+ },
+ span,
+ );
+ }
+ _ => {}
+ }
+ }
+
+ let mut emitter = Emitter::default();
+ emitter.start(&function.expressions);
+ for component in components.iter_mut() {
+ let load_expr = crate::Expression::Load {
+ pointer: *component,
+ };
+ let span = function.expressions.get_span(*component);
+ *component = function.expressions.append(load_expr, span);
+ }
+
+ match members[..] {
+ [] => {}
+ [ref member] => {
+ function.body.extend(emitter.finish(&function.expressions));
+ let span = function.expressions.get_span(components[0]);
+ function.body.push(
+ crate::Statement::Return {
+ value: components.first().cloned(),
+ },
+ span,
+ );
+ function.result = Some(crate::FunctionResult {
+ ty: member.ty,
+ binding: member.binding.clone(),
+ });
+ }
+ _ => {
+ let span = crate::Span::total_span(
+ components.iter().map(|h| function.expressions.get_span(*h)),
+ );
+ let ty = module.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Struct {
+ members,
+ span: 0xFFFF, // shouldn't matter
+ },
+ },
+ span,
+ );
+ let result_expr = function
+ .expressions
+ .append(crate::Expression::Compose { ty, components }, span);
+ function.body.extend(emitter.finish(&function.expressions));
+ function.body.push(
+ crate::Statement::Return {
+ value: Some(result_expr),
+ },
+ span,
+ );
+ function.result = Some(crate::FunctionResult { ty, binding: None });
+ }
+ }
+
+ module.entry_points.push(crate::EntryPoint {
+ name: ep.name,
+ stage: ep.stage,
+ early_depth_test: ep.early_depth_test,
+ workgroup_size: ep.workgroup_size,
+ function,
+ });
+ }
+
+ Ok(())
+ }
+}
+
+impl<'function> BlockContext<'function> {
+ /// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block)
+ fn lower(mut self) -> crate::Block {
+ fn lower_impl(
+ blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>,
+ bodies: &[super::Body],
+ body_idx: BodyIndex,
+ ) -> crate::Block {
+ let mut block = crate::Block::new();
+
+ for item in bodies[body_idx].data.iter() {
+ match *item {
+ super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()),
+ super::BodyFragment::If {
+ condition,
+ accept,
+ reject,
+ } => {
+ let accept = lower_impl(blocks, bodies, accept);
+ let reject = lower_impl(blocks, bodies, reject);
+
+ block.push(
+ crate::Statement::If {
+ condition,
+ accept,
+ reject,
+ },
+ crate::Span::default(),
+ )
+ }
+ super::BodyFragment::Loop { body, continuing } => {
+ let body = lower_impl(blocks, bodies, body);
+ let continuing = lower_impl(blocks, bodies, continuing);
+
+ block.push(
+ crate::Statement::Loop {
+ body,
+ continuing,
+ break_if: None,
+ },
+ crate::Span::default(),
+ )
+ }
+ super::BodyFragment::Switch {
+ selector,
+ ref cases,
+ default,
+ } => {
+ let mut ir_cases: Vec<_> = cases
+ .iter()
+ .map(|&(value, body_idx)| {
+ let body = lower_impl(blocks, bodies, body_idx);
+
+ // Handle simple cases that would make a fallthrough statement unreachable code
+ let fall_through = body.last().map_or(true, |s| !s.is_terminator());
+
+ crate::SwitchCase {
+ value: crate::SwitchValue::I32(value),
+ body,
+ fall_through,
+ }
+ })
+ .collect();
+ ir_cases.push(crate::SwitchCase {
+ value: crate::SwitchValue::Default,
+ body: lower_impl(blocks, bodies, default),
+ fall_through: false,
+ });
+
+ block.push(
+ crate::Statement::Switch {
+ selector,
+ cases: ir_cases,
+ },
+ crate::Span::default(),
+ )
+ }
+ super::BodyFragment::Break => {
+ block.push(crate::Statement::Break, crate::Span::default())
+ }
+ super::BodyFragment::Continue => {
+ block.push(crate::Statement::Continue, crate::Span::default())
+ }
+ }
+ }
+
+ block
+ }
+
+ lower_impl(&mut self.blocks, &self.bodies, 0)
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/image.rs b/third_party/rust/naga/src/front/spv/image.rs
new file mode 100644
index 0000000000..757843f789
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/image.rs
@@ -0,0 +1,737 @@
+use crate::arena::{Arena, Handle, UniqueArena};
+
+use super::{Error, LookupExpression, LookupHelper as _};
+
+#[derive(Clone, Debug)]
+pub(super) struct LookupSampledImage {
+ image: Handle<crate::Expression>,
+ sampler: Handle<crate::Expression>,
+}
+
+bitflags::bitflags! {
+ /// Flags describing sampling method.
+ pub struct SamplingFlags: u32 {
+ /// Regular sampling.
+ const REGULAR = 0x1;
+ /// Comparison sampling.
+ const COMPARISON = 0x2;
+ }
+}
+
+impl<'function> super::BlockContext<'function> {
+ fn get_image_expr_ty(
+ &self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, Error> {
+ match self.expressions[handle] {
+ crate::Expression::GlobalVariable(handle) => Ok(self.global_arena[handle].ty),
+ crate::Expression::FunctionArgument(i) => Ok(self.arguments[i as usize].ty),
+ ref other => Err(Error::InvalidImageExpression(other.clone())),
+ }
+ }
+}
+
+/// Options of a sampling operation.
+#[derive(Debug)]
+pub struct SamplingOptions {
+ /// Projection sampling: the division by W is expected to happen
+ /// in the texture unit.
+ pub project: bool,
+ /// Depth comparison sampling with a reference value.
+ pub compare: bool,
+}
+
+enum ExtraCoordinate {
+ ArrayLayer,
+ Projection,
+ Garbage,
+}
+
+/// Return the texture coordinates separated from the array layer,
+/// and/or divided by the projection term.
+///
+/// The Proj sampling ops expect an extra coordinate for the W.
+/// The arrayed (can't be Proj!) images expect an extra coordinate for the layer.
+fn extract_image_coordinates(
+ image_dim: crate::ImageDimension,
+ extra_coordinate: ExtraCoordinate,
+ base: Handle<crate::Expression>,
+ coordinate_ty: Handle<crate::Type>,
+ ctx: &mut super::BlockContext,
+) -> (Handle<crate::Expression>, Option<Handle<crate::Expression>>) {
+ let (given_size, kind) = match ctx.type_arena[coordinate_ty].inner {
+ crate::TypeInner::Scalar { kind, .. } => (None, kind),
+ crate::TypeInner::Vector { size, kind, .. } => (Some(size), kind),
+ ref other => unreachable!("Unexpected texture coordinate {:?}", other),
+ };
+
+ let required_size = image_dim.required_coordinate_size();
+ let required_ty = required_size.map(|size| {
+ ctx.type_arena
+ .get(&crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size,
+ kind,
+ width: 4,
+ },
+ })
+ .expect("Required coordinate type should have been set up by `parse_type_image`!")
+ });
+ let extra_expr = crate::Expression::AccessIndex {
+ base,
+ index: required_size.map_or(1, |size| size as u32),
+ };
+
+ let base_span = ctx.expressions.get_span(base);
+
+ match extra_coordinate {
+ ExtraCoordinate::ArrayLayer => {
+ let extracted = match required_size {
+ None => ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index: 0 }, base_span),
+ Some(size) => {
+ let mut components = Vec::with_capacity(size as usize);
+ for index in 0..size as u32 {
+ let comp = ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index }, base_span);
+ components.push(comp);
+ }
+ ctx.expressions.append(
+ crate::Expression::Compose {
+ ty: required_ty.unwrap(),
+ components,
+ },
+ base_span,
+ )
+ }
+ };
+ let array_index_f32 = ctx.expressions.append(extra_expr, base_span);
+ let array_index = ctx.expressions.append(
+ crate::Expression::As {
+ kind: crate::ScalarKind::Sint,
+ expr: array_index_f32,
+ convert: Some(4),
+ },
+ base_span,
+ );
+ (extracted, Some(array_index))
+ }
+ ExtraCoordinate::Projection => {
+ let projection = ctx.expressions.append(extra_expr, base_span);
+ let divided = match required_size {
+ None => {
+ let temp = ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index: 0 }, base_span);
+ ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: temp,
+ right: projection,
+ },
+ base_span,
+ )
+ }
+ Some(size) => {
+ let mut components = Vec::with_capacity(size as usize);
+ for index in 0..size as u32 {
+ let temp = ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index }, base_span);
+ let comp = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: temp,
+ right: projection,
+ },
+ base_span,
+ );
+ components.push(comp);
+ }
+ ctx.expressions.append(
+ crate::Expression::Compose {
+ ty: required_ty.unwrap(),
+ components,
+ },
+ base_span,
+ )
+ }
+ };
+ (divided, None)
+ }
+ ExtraCoordinate::Garbage if given_size == required_size => (base, None),
+ ExtraCoordinate::Garbage => {
+ use crate::SwizzleComponent as Sc;
+ let cut_expr = match required_size {
+ None => crate::Expression::AccessIndex { base, index: 0 },
+ Some(size) => crate::Expression::Swizzle {
+ size,
+ vector: base,
+ pattern: [Sc::X, Sc::Y, Sc::Z, Sc::W],
+ },
+ };
+ (ctx.expressions.append(cut_expr, base_span), None)
+ }
+ }
+}
+
+pub(super) fn patch_comparison_type(
+ flags: SamplingFlags,
+ var: &mut crate::GlobalVariable,
+ arena: &mut UniqueArena<crate::Type>,
+) -> bool {
+ if !flags.contains(SamplingFlags::COMPARISON) {
+ return true;
+ }
+ if flags == SamplingFlags::all() {
+ return false;
+ }
+
+ log::debug!("Flipping comparison for {:?}", var);
+ let original_ty = &arena[var.ty];
+ let original_ty_span = arena.get_span(var.ty);
+ let ty_inner = match original_ty.inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Sampled { multi, .. },
+ dim,
+ arrayed,
+ } => crate::TypeInner::Image {
+ class: crate::ImageClass::Depth { multi },
+ dim,
+ arrayed,
+ },
+ crate::TypeInner::Sampler { .. } => crate::TypeInner::Sampler { comparison: true },
+ ref other => unreachable!("Unexpected type for comparison mutation: {:?}", other),
+ };
+
+ let name = original_ty.name.clone();
+ var.ty = arena.insert(
+ crate::Type {
+ name,
+ inner: ty_inner,
+ },
+ original_ty_span,
+ );
+ true
+}
+
+impl<I: Iterator<Item = u32>> super::Frontend<I> {
+ pub(super) fn parse_image_couple(&mut self) -> Result<(), Error> {
+ let _result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let sampler_id = self.next()?;
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let sampler_lexp = self.lookup_expression.lookup(sampler_id)?;
+ self.lookup_sampled_image.insert(
+ result_id,
+ LookupSampledImage {
+ image: image_lexp.handle,
+ sampler: sampler_lexp.handle,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_uncouple(&mut self, block_id: spirv::Word) -> Result<(), Error> {
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let sampled_image_id = self.next()?;
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: self.lookup_sampled_image.lookup(sampled_image_id)?.image,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_write(
+ &mut self,
+ words_left: u16,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::front::Emitter,
+ block: &mut crate::Block,
+ body_idx: usize,
+ ) -> Result<crate::Statement, Error> {
+ let image_id = self.next()?;
+ let coordinate_id = self.next()?;
+ let value_id = self.next()?;
+
+ let image_ops = if words_left != 0 { self.next()? } else { 0 };
+
+ if image_ops != 0 {
+ let other = spirv::ImageOperands::from_bits_truncate(image_ops);
+ log::warn!("Unknown image write ops {:?}", other);
+ for _ in 1..words_left {
+ self.next()?;
+ }
+ }
+
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?;
+
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_handle =
+ self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx);
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+ let (coordinate, array_index) = match ctx.type_arena[image_ty].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => extract_image_coordinates(
+ dim,
+ if arrayed {
+ ExtraCoordinate::ArrayLayer
+ } else {
+ ExtraCoordinate::Garbage
+ },
+ coord_handle,
+ coord_type_handle,
+ ctx,
+ ),
+ _ => return Err(Error::InvalidImage(image_ty)),
+ };
+
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let value = self.get_expr_handle(value_id, value_lexp, ctx, emitter, block, body_idx);
+
+ Ok(crate::Statement::ImageStore {
+ image: image_lexp.handle,
+ coordinate,
+ array_index,
+ value,
+ })
+ }
+
+ pub(super) fn parse_image_load(
+ &mut self,
+ mut words_left: u16,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::front::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let coordinate_id = self.next()?;
+
+ let mut image_ops = if words_left != 0 {
+ words_left -= 1;
+ self.next()?
+ } else {
+ 0
+ };
+
+ let mut sample = None;
+ let mut level = None;
+ while image_ops != 0 {
+ let bit = 1 << image_ops.trailing_zeros();
+ match spirv::ImageOperands::from_bits_truncate(bit) {
+ spirv::ImageOperands::LOD => {
+ let lod_expr = self.next()?;
+ let lod_lexp = self.lookup_expression.lookup(lod_expr)?;
+ let lod_handle =
+ self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx);
+ level = Some(lod_handle);
+ words_left -= 1;
+ }
+ spirv::ImageOperands::SAMPLE => {
+ let sample_expr = self.next()?;
+ let sample_handle = self.lookup_expression.lookup(sample_expr)?.handle;
+ sample = Some(sample_handle);
+ words_left -= 1;
+ }
+ other => {
+ log::warn!("Unknown image load op {:?}", other);
+ for _ in 0..words_left {
+ self.next()?;
+ }
+ break;
+ }
+ }
+ image_ops ^= bit;
+ }
+
+ // No need to call get_expr_handle here since only globals/arguments are
+ // allowed as images and they are always in the root scope
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?;
+
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_handle =
+ self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx);
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+ let (coordinate, array_index) = match ctx.type_arena[image_ty].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => extract_image_coordinates(
+ dim,
+ if arrayed {
+ ExtraCoordinate::ArrayLayer
+ } else {
+ ExtraCoordinate::Garbage
+ },
+ coord_handle,
+ coord_type_handle,
+ ctx,
+ ),
+ _ => return Err(Error::InvalidImage(image_ty)),
+ };
+
+ let expr = crate::Expression::ImageLoad {
+ image: image_lexp.handle,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn parse_image_sample(
+ &mut self,
+ mut words_left: u16,
+ options: SamplingOptions,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::front::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let sampled_image_id = self.next()?;
+ let coordinate_id = self.next()?;
+ let dref_id = if options.compare {
+ Some(self.next()?)
+ } else {
+ None
+ };
+
+ let mut image_ops = if words_left != 0 {
+ words_left -= 1;
+ self.next()?
+ } else {
+ 0
+ };
+
+ let mut level = crate::SampleLevel::Auto;
+ let mut offset = None;
+ while image_ops != 0 {
+ let bit = 1 << image_ops.trailing_zeros();
+ match spirv::ImageOperands::from_bits_truncate(bit) {
+ spirv::ImageOperands::BIAS => {
+ let bias_expr = self.next()?;
+ let bias_lexp = self.lookup_expression.lookup(bias_expr)?;
+ let bias_handle =
+ self.get_expr_handle(bias_expr, bias_lexp, ctx, emitter, block, body_idx);
+ level = crate::SampleLevel::Bias(bias_handle);
+ words_left -= 1;
+ }
+ spirv::ImageOperands::LOD => {
+ let lod_expr = self.next()?;
+ let lod_lexp = self.lookup_expression.lookup(lod_expr)?;
+ let lod_handle =
+ self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx);
+ level = if options.compare {
+ log::debug!("Assuming {:?} is zero", lod_handle);
+ crate::SampleLevel::Zero
+ } else {
+ crate::SampleLevel::Exact(lod_handle)
+ };
+ words_left -= 1;
+ }
+ spirv::ImageOperands::GRAD => {
+ let grad_x_expr = self.next()?;
+ let grad_x_lexp = self.lookup_expression.lookup(grad_x_expr)?;
+ let grad_x_handle = self.get_expr_handle(
+ grad_x_expr,
+ grad_x_lexp,
+ ctx,
+ emitter,
+ block,
+ body_idx,
+ );
+ let grad_y_expr = self.next()?;
+ let grad_y_lexp = self.lookup_expression.lookup(grad_y_expr)?;
+ let grad_y_handle = self.get_expr_handle(
+ grad_y_expr,
+ grad_y_lexp,
+ ctx,
+ emitter,
+ block,
+ body_idx,
+ );
+ level = if options.compare {
+ log::debug!(
+ "Assuming gradients {:?} and {:?} are not greater than 1",
+ grad_x_handle,
+ grad_y_handle
+ );
+ crate::SampleLevel::Zero
+ } else {
+ crate::SampleLevel::Gradient {
+ x: grad_x_handle,
+ y: grad_y_handle,
+ }
+ };
+ words_left -= 2;
+ }
+ spirv::ImageOperands::CONST_OFFSET => {
+ let offset_constant = self.next()?;
+ let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle;
+ offset = Some(offset_handle);
+ words_left -= 1;
+ }
+ other => {
+ log::warn!("Unknown image sample operand {:?}", other);
+ for _ in 0..words_left {
+ self.next()?;
+ }
+ break;
+ }
+ }
+ image_ops ^= bit;
+ }
+
+ let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?;
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_handle =
+ self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx);
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+
+ let sampling_bit = if options.compare {
+ SamplingFlags::COMPARISON
+ } else {
+ SamplingFlags::REGULAR
+ };
+
+ let image_ty = match ctx.expressions[si_lexp.image] {
+ crate::Expression::GlobalVariable(handle) => {
+ if let Some(flags) = self.handle_sampling.get_mut(&handle) {
+ *flags |= sampling_bit;
+ }
+
+ ctx.global_arena[handle].ty
+ }
+
+ crate::Expression::FunctionArgument(i) => {
+ ctx.parameter_sampling[i as usize] |= sampling_bit;
+ ctx.arguments[i as usize].ty
+ }
+
+ crate::Expression::Access { base, .. } => match ctx.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => {
+ if let Some(flags) = self.handle_sampling.get_mut(&handle) {
+ *flags |= sampling_bit;
+ }
+
+ match ctx.type_arena[ctx.global_arena[handle].ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => base,
+ _ => return Err(Error::InvalidGlobalVar(ctx.expressions[base].clone())),
+ }
+ }
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ },
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ };
+
+ match ctx.expressions[si_lexp.sampler] {
+ crate::Expression::GlobalVariable(handle) => {
+ *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit;
+ }
+
+ crate::Expression::FunctionArgument(i) => {
+ ctx.parameter_sampling[i as usize] |= sampling_bit;
+ }
+
+ crate::Expression::Access { base, .. } => match ctx.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => {
+ *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit;
+ }
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ },
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ }
+
+ let ((coordinate, array_index), depth_ref) = match ctx.type_arena[image_ty].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => (
+ extract_image_coordinates(
+ dim,
+ if options.project {
+ ExtraCoordinate::Projection
+ } else if arrayed {
+ ExtraCoordinate::ArrayLayer
+ } else {
+ ExtraCoordinate::Garbage
+ },
+ coord_handle,
+ coord_type_handle,
+ ctx,
+ ),
+ {
+ match dref_id {
+ Some(id) => {
+ let expr_lexp = self.lookup_expression.lookup(id)?;
+ let mut expr =
+ self.get_expr_handle(id, expr_lexp, ctx, emitter, block, body_idx);
+
+ if options.project {
+ let required_size = dim.required_coordinate_size();
+ let right = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: coord_handle,
+ index: required_size.map_or(1, |size| size as u32),
+ },
+ crate::Span::default(),
+ );
+ expr = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: expr,
+ right,
+ },
+ crate::Span::default(),
+ )
+ };
+ Some(expr)
+ }
+ None => None,
+ }
+ },
+ ),
+ _ => return Err(Error::InvalidImage(image_ty)),
+ };
+
+ let expr = crate::Expression::ImageSample {
+ image: si_lexp.image,
+ sampler: si_lexp.sampler,
+ gather: None, //TODO
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_query_size(
+ &mut self,
+ at_level: bool,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::front::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let level = if at_level {
+ let level_id = self.next()?;
+ let level_lexp = self.lookup_expression.lookup(level_id)?;
+ Some(self.get_expr_handle(level_id, level_lexp, ctx, emitter, block, body_idx))
+ } else {
+ None
+ };
+
+ // No need to call get_expr_handle here since only globals/arguments are
+ // allowed as images and they are always in the root scope
+ //TODO: handle arrays and cubes
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+
+ let expr = crate::Expression::ImageQuery {
+ image: image_lexp.handle,
+ query: crate::ImageQuery::Size { level },
+ };
+ let expr = crate::Expression::As {
+ expr: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ kind: crate::ScalarKind::Sint,
+ convert: Some(4),
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_query_other(
+ &mut self,
+ query: crate::ImageQuery,
+ expressions: &mut Arena<crate::Expression>,
+ block_id: spirv::Word,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+
+ // No need to call get_expr_handle here since only globals/arguments are
+ // allowed as images and they are always in the root scope
+ let image_lexp = self.lookup_expression.lookup(image_id)?.clone();
+
+ let expr = crate::Expression::ImageQuery {
+ image: image_lexp.handle,
+ query,
+ };
+ let expr = crate::Expression::As {
+ expr: expressions.append(expr, self.span_from_with_op(start)),
+ kind: crate::ScalarKind::Sint,
+ convert: Some(4),
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs
new file mode 100644
index 0000000000..c69a230cb0
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/mod.rs
@@ -0,0 +1,5195 @@
+/*!
+Frontend for [SPIR-V][spv] (Standard Portable Intermediate Representation).
+
+## ID lookups
+
+Our IR links to everything with `Handle`, while SPIR-V uses IDs.
+In order to keep track of the associations, the parser has many lookup tables.
+There map `spv::Word` into a specific IR handle, plus potentially a bit of
+extra info, such as the related SPIR-V type ID.
+TODO: would be nice to find ways that avoid looking up as much
+
+## Inputs/Outputs
+
+We create a private variable for each input/output. The relevant inputs are
+populated at the start of an entry point. The outputs are saved at the end.
+
+The function associated with an entry point is wrapped in another function,
+such that we can handle any `Return` statements without problems.
+
+## Row-major matrices
+
+We don't handle them natively, since the IR only expects column majority.
+Instead, we detect when such matrix is accessed in the `OpAccessChain`,
+and we generate a parallel expression that loads the value, but transposed.
+This value then gets used instead of `OpLoad` result later on.
+
+[spv]: https://www.khronos.org/registry/SPIR-V/
+*/
+
+mod convert;
+mod error;
+mod function;
+mod image;
+mod null;
+
+use convert::*;
+pub use error::Error;
+use function::*;
+
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ proc::{Alignment, Layouter},
+ FastHashMap, FastHashSet,
+};
+
+use num_traits::cast::FromPrimitive;
+use petgraph::graphmap::GraphMap;
+use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf};
+
+pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
+ spirv::Capability::Shader,
+ spirv::Capability::VulkanMemoryModel,
+ spirv::Capability::ClipDistance,
+ spirv::Capability::CullDistance,
+ spirv::Capability::SampleRateShading,
+ spirv::Capability::DerivativeControl,
+ spirv::Capability::InterpolationFunction,
+ spirv::Capability::Matrix,
+ spirv::Capability::ImageQuery,
+ spirv::Capability::Sampled1D,
+ spirv::Capability::Image1D,
+ spirv::Capability::SampledCubeArray,
+ spirv::Capability::ImageCubeArray,
+ spirv::Capability::ImageMSArray,
+ spirv::Capability::StorageImageExtendedFormats,
+ spirv::Capability::Sampled1D,
+ spirv::Capability::SampledCubeArray,
+ spirv::Capability::Int8,
+ spirv::Capability::Int16,
+ spirv::Capability::Int64,
+ spirv::Capability::Float16,
+ spirv::Capability::Float64,
+ spirv::Capability::Geometry,
+ spirv::Capability::MultiView,
+ // tricky ones
+ spirv::Capability::UniformBufferArrayDynamicIndexing,
+ spirv::Capability::StorageBufferArrayDynamicIndexing,
+];
+pub const SUPPORTED_EXTENSIONS: &[&str] = &[
+ "SPV_KHR_storage_buffer_storage_class",
+ "SPV_KHR_vulkan_memory_model",
+ "SPV_KHR_multiview",
+];
+pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"];
+
+#[derive(Copy, Clone)]
+pub struct Instruction {
+ op: spirv::Op,
+ wc: u16,
+}
+
+impl Instruction {
+ const fn expect(self, count: u16) -> Result<(), Error> {
+ if self.wc == count {
+ Ok(())
+ } else {
+ Err(Error::InvalidOperandCount(self.op, self.wc))
+ }
+ }
+
+ fn expect_at_least(self, count: u16) -> Result<u16, Error> {
+ self.wc
+ .checked_sub(count)
+ .ok_or(Error::InvalidOperandCount(self.op, self.wc))
+ }
+}
+
+impl crate::TypeInner {
+ fn can_comparison_sample(&self, module: &crate::Module) -> bool {
+ match *self {
+ crate::TypeInner::Image {
+ class:
+ crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: false,
+ },
+ ..
+ } => true,
+ crate::TypeInner::Sampler { .. } => true,
+ crate::TypeInner::BindingArray { base, .. } => {
+ module.types[base].inner.can_comparison_sample(module)
+ }
+ _ => false,
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
+pub enum ModuleState {
+ Empty,
+ Capability,
+ Extension,
+ ExtInstImport,
+ MemoryModel,
+ EntryPoint,
+ ExecutionMode,
+ Source,
+ Name,
+ ModuleProcessed,
+ Annotation,
+ Type,
+ Function,
+}
+
+trait LookupHelper {
+ type Target;
+ fn lookup(&self, key: spirv::Word) -> Result<&Self::Target, Error>;
+}
+
+impl<T> LookupHelper for FastHashMap<spirv::Word, T> {
+ type Target = T;
+ fn lookup(&self, key: spirv::Word) -> Result<&T, Error> {
+ self.get(&key).ok_or(Error::InvalidId(key))
+ }
+}
+
+impl crate::ImageDimension {
+ const fn required_coordinate_size(&self) -> Option<crate::VectorSize> {
+ match *self {
+ crate::ImageDimension::D1 => None,
+ crate::ImageDimension::D2 => Some(crate::VectorSize::Bi),
+ crate::ImageDimension::D3 => Some(crate::VectorSize::Tri),
+ crate::ImageDimension::Cube => Some(crate::VectorSize::Tri),
+ }
+ }
+}
+
+type MemberIndex = u32;
+
+bitflags::bitflags! {
+ #[derive(Default)]
+ struct DecorationFlags: u32 {
+ const NON_READABLE = 0x1;
+ const NON_WRITABLE = 0x2;
+ }
+}
+
+impl DecorationFlags {
+ fn to_storage_access(self) -> crate::StorageAccess {
+ let mut access = crate::StorageAccess::all();
+ if self.contains(DecorationFlags::NON_READABLE) {
+ access &= !crate::StorageAccess::LOAD;
+ }
+ if self.contains(DecorationFlags::NON_WRITABLE) {
+ access &= !crate::StorageAccess::STORE;
+ }
+ access
+ }
+}
+
+#[derive(Debug, PartialEq)]
+enum Majority {
+ Column,
+ Row,
+}
+
+#[derive(Debug, Default)]
+struct Decoration {
+ name: Option<String>,
+ built_in: Option<spirv::Word>,
+ location: Option<spirv::Word>,
+ desc_set: Option<spirv::Word>,
+ desc_index: Option<spirv::Word>,
+ specialization: Option<spirv::Word>,
+ storage_buffer: bool,
+ offset: Option<spirv::Word>,
+ array_stride: Option<NonZeroU32>,
+ matrix_stride: Option<NonZeroU32>,
+ matrix_major: Option<Majority>,
+ invariant: bool,
+ interpolation: Option<crate::Interpolation>,
+ sampling: Option<crate::Sampling>,
+ flags: DecorationFlags,
+}
+
+impl Decoration {
+ fn debug_name(&self) -> &str {
+ match self.name {
+ Some(ref name) => name.as_str(),
+ None => "?",
+ }
+ }
+
+ const fn resource_binding(&self) -> Option<crate::ResourceBinding> {
+ match *self {
+ Decoration {
+ desc_set: Some(group),
+ desc_index: Some(binding),
+ ..
+ } => Some(crate::ResourceBinding { group, binding }),
+ _ => None,
+ }
+ }
+
+ fn io_binding(&self) -> Result<crate::Binding, Error> {
+ match *self {
+ Decoration {
+ built_in: Some(built_in),
+ location: None,
+ invariant,
+ ..
+ } => Ok(crate::Binding::BuiltIn(map_builtin(built_in, invariant)?)),
+ Decoration {
+ built_in: None,
+ location: Some(location),
+ interpolation,
+ sampling,
+ ..
+ } => Ok(crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ }),
+ _ => Err(Error::MissingDecoration(spirv::Decoration::Location)),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct LookupFunctionType {
+ parameter_type_ids: Vec<spirv::Word>,
+ return_type_id: spirv::Word,
+}
+
+struct LookupFunction {
+ handle: Handle<crate::Function>,
+ parameters_sampling: Vec<image::SamplingFlags>,
+}
+
+#[derive(Debug)]
+struct EntryPoint {
+ stage: crate::ShaderStage,
+ name: String,
+ early_depth_test: Option<crate::EarlyDepthTest>,
+ workgroup_size: [u32; 3],
+ variable_ids: Vec<spirv::Word>,
+}
+
+#[derive(Clone, Debug)]
+struct LookupType {
+ handle: Handle<crate::Type>,
+ base_id: Option<spirv::Word>,
+}
+
+#[derive(Debug)]
+struct LookupConstant {
+ handle: Handle<crate::Constant>,
+ type_id: spirv::Word,
+}
+
+#[derive(Debug)]
+enum Variable {
+ Global,
+ Input(crate::FunctionArgument),
+ Output(crate::FunctionResult),
+}
+
+#[derive(Debug)]
+struct LookupVariable {
+ inner: Variable,
+ handle: Handle<crate::GlobalVariable>,
+ type_id: spirv::Word,
+}
+
+/// Information about SPIR-V result ids, stored in `Parser::lookup_expression`.
+#[derive(Clone, Debug)]
+struct LookupExpression {
+ /// The `Expression` constructed for this result.
+ ///
+ /// Note that, while a SPIR-V result id can be used in any block dominated
+ /// by its definition, a Naga `Expression` is only in scope for the rest of
+ /// its subtree. `Parser::get_expr_handle` takes care of spilling the result
+ /// to a `LocalVariable` which can then be used anywhere.
+ handle: Handle<crate::Expression>,
+
+ /// The SPIR-V type of this result.
+ type_id: spirv::Word,
+
+ /// The label id of the block that defines this expression.
+ ///
+ /// This is zero for globals, constants, and function parameters, since they
+ /// originate outside any function's block.
+ block_id: spirv::Word,
+}
+
+#[derive(Debug)]
+struct LookupMember {
+ type_id: spirv::Word,
+ // This is true for either matrices, or arrays of matrices (yikes).
+ row_major: bool,
+}
+
+#[derive(Clone, Debug)]
+enum LookupLoadOverride {
+ /// For arrays of matrices, we track them but not loading yet.
+ Pending,
+ /// For matrices, vectors, and scalars, we pre-load the data.
+ Loaded(Handle<crate::Expression>),
+}
+
+#[derive(PartialEq)]
+enum ExtendedClass {
+ Global(crate::AddressSpace),
+ Input,
+ Output,
+}
+
+#[derive(Clone, Debug)]
+pub struct Options {
+ /// The IR coordinate space matches all the APIs except SPIR-V,
+ /// so by default we flip the Y coordinate of the `BuiltIn::Position`.
+ /// This flag can be used to avoid this.
+ pub adjust_coordinate_space: bool,
+ /// Only allow shaders with the known set of capabilities.
+ pub strict_capabilities: bool,
+ pub block_ctx_dump_prefix: Option<PathBuf>,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ adjust_coordinate_space: true,
+ strict_capabilities: false,
+ block_ctx_dump_prefix: None,
+ }
+ }
+}
+
+/// An index into the `BlockContext::bodies` table.
+type BodyIndex = usize;
+
+/// An intermediate representation of a Naga [`Statement`].
+///
+/// `Body` and `BodyFragment` values form a tree: the `BodyIndex` fields of the
+/// variants are indices of the child `Body` values in [`BlockContext::bodies`].
+/// The `lower` function assembles the final `Statement` tree from this `Body`
+/// tree. See [`BlockContext`] for details.
+///
+/// [`Statement`]: crate::Statement
+#[derive(Debug)]
+enum BodyFragment {
+ BlockId(spirv::Word),
+ If {
+ condition: Handle<crate::Expression>,
+ accept: BodyIndex,
+ reject: BodyIndex,
+ },
+ Loop {
+ body: BodyIndex,
+ continuing: BodyIndex,
+ },
+ Switch {
+ selector: Handle<crate::Expression>,
+ cases: Vec<(i32, BodyIndex)>,
+ default: BodyIndex,
+ },
+ Break,
+ Continue,
+}
+
+/// An intermediate representation of a Naga [`Block`].
+///
+/// This will be assembled into a `Block` once we've added spills for phi nodes
+/// and out-of-scope expressions. See [`BlockContext`] for details.
+///
+/// [`Block`]: crate::Block
+#[derive(Debug)]
+struct Body {
+ /// The index of the direct parent of this body
+ parent: usize,
+ data: Vec<BodyFragment>,
+}
+
+impl Body {
+ /// Creates a new empty `Body` with the specified `parent`
+ pub const fn with_parent(parent: usize) -> Self {
+ Body {
+ parent,
+ data: Vec::new(),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct PhiExpression {
+ /// The local variable used for the phi node
+ local: Handle<crate::LocalVariable>,
+ /// List of (expression, block)
+ expressions: Vec<(spirv::Word, spirv::Word)>,
+}
+
+#[derive(Debug)]
+enum MergeBlockInformation {
+ LoopMerge,
+ LoopContinue,
+ SelectionMerge,
+ SwitchMerge,
+}
+
+/// Fragments of Naga IR, to be assembled into `Statements` once data flow is
+/// resolved.
+///
+/// We can't build a Naga `Statement` tree directly from SPIR-V blocks for two
+/// main reasons:
+///
+/// - A SPIR-V expression can be used in any SPIR-V block dominated by its
+/// definition, whereas Naga expressions are scoped to the rest of their
+/// subtree. This means that discovering an expression use later in the
+/// function retroactively requires us to have spilled that expression into a
+/// local variable back before we left its scope.
+///
+/// - We translate SPIR-V OpPhi expressions as Naga local variables in which we
+/// store the appropriate value before jumping to the OpPhi's block.
+///
+/// Both cases require us to go back and amend previously generated Naga IR
+/// based on things we discover later. But modifying old blocks in arbitrary
+/// spots in a `Statement` tree is awkward.
+///
+/// Instead, as we iterate through the function's body, we accumulate
+/// control-flow-free fragments of Naga IR in the [`blocks`] table, while
+/// building a skeleton of the Naga `Statement` tree in [`bodies`]. We note any
+/// spills and temporaries we must introduce in [`phis`].
+///
+/// Finally, once we've processed the entire function, we add temporaries and
+/// spills to the fragmentary `Blocks` as directed by `phis`, and assemble them
+/// into the final Naga `Statement` tree as directed by `bodies`.
+///
+/// [`blocks`]: BlockContext::blocks
+/// [`bodies`]: BlockContext::bodies
+/// [`phis`]: BlockContext::phis
+/// [`lower`]: function::lower
+#[derive(Debug)]
+struct BlockContext<'function> {
+ /// Phi nodes encountered when parsing the function, used to generate spills
+ /// to local variables.
+ phis: Vec<PhiExpression>,
+
+ /// Fragments of control-flow-free Naga IR.
+ ///
+ /// These will be stitched together into a proper `Statement` tree according
+ /// to `bodies`, once parsing is complete.
+ blocks: FastHashMap<spirv::Word, crate::Block>,
+
+ /// Map from block label ids to the index of the corresponding `Body` in
+ /// `bodies`.
+ body_for_label: FastHashMap<spirv::Word, BodyIndex>,
+
+ /// SPIR-V metadata about merge/continue blocks.
+ mergers: FastHashMap<spirv::Word, MergeBlockInformation>,
+
+ /// A table of `Body` values, each representing a block in the final IR.
+ bodies: Vec<Body>,
+
+ /// Id of the function currently being processed
+ function_id: spirv::Word,
+ /// Expression arena of the function currently being processed
+ expressions: &'function mut Arena<crate::Expression>,
+ /// Local variables arena of the function currently being processed
+ local_arena: &'function mut Arena<crate::LocalVariable>,
+ /// Constants arena of the module being processed
+ const_arena: &'function mut Arena<crate::Constant>,
+ /// Type arena of the module being processed
+ type_arena: &'function UniqueArena<crate::Type>,
+ /// Global arena of the module being processed
+ global_arena: &'function Arena<crate::GlobalVariable>,
+ /// Arguments of the function currently being processed
+ arguments: &'function [crate::FunctionArgument],
+ /// Metadata about the usage of function parameters as sampling objects
+ parameter_sampling: &'function mut [image::SamplingFlags],
+}
+
+enum SignAnchor {
+ Result,
+ Operand,
+}
+
+pub struct Frontend<I> {
+ data: I,
+ data_offset: usize,
+ state: ModuleState,
+ layouter: Layouter,
+ temp_bytes: Vec<u8>,
+ ext_glsl_id: Option<spirv::Word>,
+ future_decor: FastHashMap<spirv::Word, Decoration>,
+ future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>,
+ lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>,
+ handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>,
+ lookup_type: FastHashMap<spirv::Word, LookupType>,
+ lookup_void_type: Option<spirv::Word>,
+ lookup_storage_buffer_types: FastHashMap<Handle<crate::Type>, crate::StorageAccess>,
+ // Lookup for samplers and sampled images, storing flags on how they are used.
+ lookup_constant: FastHashMap<spirv::Word, LookupConstant>,
+ lookup_variable: FastHashMap<spirv::Word, LookupVariable>,
+ lookup_expression: FastHashMap<spirv::Word, LookupExpression>,
+ // Load overrides are used to work around row-major matrices
+ lookup_load_override: FastHashMap<spirv::Word, LookupLoadOverride>,
+ lookup_sampled_image: FastHashMap<spirv::Word, image::LookupSampledImage>,
+ lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>,
+ lookup_function: FastHashMap<spirv::Word, LookupFunction>,
+ lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>,
+ //Note: each `OpFunctionCall` gets a single entry here, indexed by the
+ // dummy `Handle<crate::Function>` of the call site.
+ deferred_function_calls: Vec<spirv::Word>,
+ dummy_functions: Arena<crate::Function>,
+ // Graph of all function calls through the module.
+ // It's used to sort the functions (as nodes) topologically,
+ // so that in the IR any called function is already known.
+ function_call_graph: GraphMap<spirv::Word, (), petgraph::Directed>,
+ options: Options,
+ index_constants: Vec<Handle<crate::Constant>>,
+ index_constant_expressions: Vec<Handle<crate::Expression>>,
+
+ /// Maps for a switch from a case target to the respective body and associated literals that
+ /// use that target block id.
+ ///
+ /// Used to preserve allocations between instruction parsing.
+ switch_cases: indexmap::IndexMap<
+ spirv::Word,
+ (BodyIndex, Vec<i32>),
+ std::hash::BuildHasherDefault<rustc_hash::FxHasher>,
+ >,
+
+ /// Tracks access to gl_PerVertex's builtins, it is used to cull unused builtins since initializing those can
+ /// affect performance and the mere presence of some of these builtins might cause backends to error since they
+ /// might be unsupported.
+ ///
+ /// The problematic builtins are: PointSize, ClipDistance and CullDistance.
+ ///
+ /// glslang declares those by default even though they are never written to
+ /// (see <https://github.com/KhronosGroup/glslang/issues/1868>)
+ gl_per_vertex_builtin_access: FastHashSet<crate::BuiltIn>,
+}
+
+impl<I: Iterator<Item = u32>> Frontend<I> {
+ pub fn new(data: I, options: &Options) -> Self {
+ Frontend {
+ data,
+ data_offset: 0,
+ state: ModuleState::Empty,
+ layouter: Layouter::default(),
+ temp_bytes: Vec::new(),
+ ext_glsl_id: None,
+ future_decor: FastHashMap::default(),
+ future_member_decor: FastHashMap::default(),
+ handle_sampling: FastHashMap::default(),
+ lookup_member: FastHashMap::default(),
+ lookup_type: FastHashMap::default(),
+ lookup_void_type: None,
+ lookup_storage_buffer_types: FastHashMap::default(),
+ lookup_constant: FastHashMap::default(),
+ lookup_variable: FastHashMap::default(),
+ lookup_expression: FastHashMap::default(),
+ lookup_load_override: FastHashMap::default(),
+ lookup_sampled_image: FastHashMap::default(),
+ lookup_function_type: FastHashMap::default(),
+ lookup_function: FastHashMap::default(),
+ lookup_entry_point: FastHashMap::default(),
+ deferred_function_calls: Vec::default(),
+ dummy_functions: Arena::new(),
+ function_call_graph: GraphMap::new(),
+ options: options.clone(),
+ index_constants: Vec::new(),
+ index_constant_expressions: Vec::new(),
+ switch_cases: indexmap::IndexMap::default(),
+ gl_per_vertex_builtin_access: FastHashSet::default(),
+ }
+ }
+
+ fn span_from(&self, from: usize) -> crate::Span {
+ crate::Span::from(from..self.data_offset)
+ }
+
+ fn span_from_with_op(&self, from: usize) -> crate::Span {
+ crate::Span::from((from - 4)..self.data_offset)
+ }
+
+ fn next(&mut self) -> Result<u32, Error> {
+ if let Some(res) = self.data.next() {
+ self.data_offset += 4;
+ Ok(res)
+ } else {
+ Err(Error::IncompleteData)
+ }
+ }
+
+ fn next_inst(&mut self) -> Result<Instruction, Error> {
+ let word = self.next()?;
+ let (wc, opcode) = ((word >> 16) as u16, (word & 0xffff) as u16);
+ if wc == 0 {
+ return Err(Error::InvalidWordCount);
+ }
+ let op = spirv::Op::from_u16(opcode).ok_or(Error::UnknownInstruction(opcode))?;
+
+ Ok(Instruction { op, wc })
+ }
+
+ fn next_string(&mut self, mut count: u16) -> Result<(String, u16), Error> {
+ self.temp_bytes.clear();
+ loop {
+ if count == 0 {
+ return Err(Error::BadString);
+ }
+ count -= 1;
+ let chars = self.next()?.to_le_bytes();
+ let pos = chars.iter().position(|&c| c == 0).unwrap_or(4);
+ self.temp_bytes.extend_from_slice(&chars[..pos]);
+ if pos < 4 {
+ break;
+ }
+ }
+ std::str::from_utf8(&self.temp_bytes)
+ .map(|s| (s.to_owned(), count))
+ .map_err(|_| Error::BadString)
+ }
+
+ fn next_decoration(
+ &mut self,
+ inst: Instruction,
+ base_words: u16,
+ dec: &mut Decoration,
+ ) -> Result<(), Error> {
+ let raw = self.next()?;
+ let dec_typed = spirv::Decoration::from_u32(raw).ok_or(Error::InvalidDecoration(raw))?;
+ log::trace!("\t\t{}: {:?}", dec.debug_name(), dec_typed);
+ match dec_typed {
+ spirv::Decoration::BuiltIn => {
+ inst.expect(base_words + 2)?;
+ dec.built_in = Some(self.next()?);
+ }
+ spirv::Decoration::Location => {
+ inst.expect(base_words + 2)?;
+ dec.location = Some(self.next()?);
+ }
+ spirv::Decoration::DescriptorSet => {
+ inst.expect(base_words + 2)?;
+ dec.desc_set = Some(self.next()?);
+ }
+ spirv::Decoration::Binding => {
+ inst.expect(base_words + 2)?;
+ dec.desc_index = Some(self.next()?);
+ }
+ spirv::Decoration::BufferBlock => {
+ dec.storage_buffer = true;
+ }
+ spirv::Decoration::Offset => {
+ inst.expect(base_words + 2)?;
+ dec.offset = Some(self.next()?);
+ }
+ spirv::Decoration::ArrayStride => {
+ inst.expect(base_words + 2)?;
+ dec.array_stride = NonZeroU32::new(self.next()?);
+ }
+ spirv::Decoration::MatrixStride => {
+ inst.expect(base_words + 2)?;
+ dec.matrix_stride = NonZeroU32::new(self.next()?);
+ }
+ spirv::Decoration::Invariant => {
+ dec.invariant = true;
+ }
+ spirv::Decoration::NoPerspective => {
+ dec.interpolation = Some(crate::Interpolation::Linear);
+ }
+ spirv::Decoration::Flat => {
+ dec.interpolation = Some(crate::Interpolation::Flat);
+ }
+ spirv::Decoration::Centroid => {
+ dec.sampling = Some(crate::Sampling::Centroid);
+ }
+ spirv::Decoration::Sample => {
+ dec.sampling = Some(crate::Sampling::Sample);
+ }
+ spirv::Decoration::NonReadable => {
+ dec.flags |= DecorationFlags::NON_READABLE;
+ }
+ spirv::Decoration::NonWritable => {
+ dec.flags |= DecorationFlags::NON_WRITABLE;
+ }
+ spirv::Decoration::ColMajor => {
+ dec.matrix_major = Some(Majority::Column);
+ }
+ spirv::Decoration::RowMajor => {
+ dec.matrix_major = Some(Majority::Row);
+ }
+ spirv::Decoration::SpecId => {
+ dec.specialization = Some(self.next()?);
+ }
+ other => {
+ log::warn!("Unknown decoration {:?}", other);
+ for _ in base_words + 1..inst.wc {
+ let _var = self.next()?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ /// Return the Naga `Expression` for a given SPIR-V result `id`.
+ ///
+ /// `lookup` must be the `LookupExpression` for `id`.
+ ///
+ /// SPIR-V result ids can be used by any block dominated by the id's
+ /// definition, but Naga `Expressions` are only in scope for the remainder
+ /// of their `Statement` subtree. This means that the `Expression` generated
+ /// for `id` may no longer be in scope. In such cases, this function takes
+ /// care of spilling the value of `id` to a `LocalVariable` which can then
+ /// be used anywhere. The SPIR-V domination rule ensures that the
+ /// `LocalVariable` has been initialized before it is used.
+ ///
+ /// The `body_idx` argument should be the index of the `Body` that hopes to
+ /// use `id`'s `Expression`.
+ fn get_expr_handle(
+ &self,
+ id: spirv::Word,
+ lookup: &LookupExpression,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ body_idx: BodyIndex,
+ ) -> Handle<crate::Expression> {
+ // What `Body` was `id` defined in?
+ let expr_body_idx = ctx
+ .body_for_label
+ .get(&lookup.block_id)
+ .copied()
+ .unwrap_or(0);
+
+ // Don't need to do a load/store if the expression is in the main body
+ // or if the expression is in the same body as where the query was
+ // requested. The body_idx might actually not be the final one if a loop
+ // or conditional occurs but in those cases we know that the new body
+ // will be a subscope of the body that was passed so we can still reuse
+ // the handle and not issue a load/store.
+ if is_parent(body_idx, expr_body_idx, ctx) {
+ lookup.handle
+ } else {
+ // Add a temporary variable of the same type which will be used to
+ // store the original expression and used in the current block
+ let ty = self.lookup_type[&lookup.type_id].handle;
+ let local = ctx.local_arena.append(
+ crate::LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ crate::Span::default(),
+ );
+
+ block.extend(emitter.finish(ctx.expressions));
+ let pointer = ctx.expressions.append(
+ crate::Expression::LocalVariable(local),
+ crate::Span::default(),
+ );
+ emitter.start(ctx.expressions);
+ let expr = ctx
+ .expressions
+ .append(crate::Expression::Load { pointer }, crate::Span::default());
+
+ // Add a slightly odd entry to the phi table, so that while `id`'s
+ // `Expression` is still in scope, the usual phi processing will
+ // spill its value to `local`, where we can find it later.
+ //
+ // This pretends that the block in which `id` is defined is the
+ // predecessor of some other block with a phi in it that cites id as
+ // one of its sources, and uses `local` as its variable. There is no
+ // such phi, but nobody needs to know that.
+ ctx.phis.push(PhiExpression {
+ local,
+ expressions: vec![(id, lookup.block_id)],
+ });
+
+ expr
+ }
+ }
+
+ fn parse_expr_unary_op(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::UnaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p_id = self.next()?;
+
+ let p_lexp = self.lookup_expression.lookup(p_id)?;
+ let handle = self.get_expr_handle(p_id, p_lexp, ctx, emitter, block, body_idx);
+
+ let expr = crate::Expression::Unary { op, expr: handle };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_binary_op(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+
+ let expr = crate::Expression::Binary { op, left, right };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ /// A more complicated version of the unary op,
+ /// where we force the operand to have the same type as the result.
+ fn parse_expr_unary_op_sign_adjusted(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::UnaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+
+ let result_lookup_ty = self.lookup_type.lookup(result_type_id)?;
+ let kind = ctx.type_arena[result_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let expr = crate::Expression::Unary {
+ op,
+ expr: if p1_lexp.type_id == result_type_id {
+ left
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ /// A more complicated version of the binary op,
+ /// where we force the operand to have the same type as the result.
+ /// This is mostly needed for "i++" and "i--" coming from GLSL.
+ #[allow(clippy::too_many_arguments)]
+ fn parse_expr_binary_op_sign_adjusted(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ // For arithmetic operations, we need the sign of operands to match the result.
+ // For boolean operations, however, the operands need to match the signs, but
+ // result is always different - a boolean.
+ anchor: SignAnchor,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+
+ let expected_type_id = match anchor {
+ SignAnchor::Result => result_type_id,
+ SignAnchor::Operand => p1_lexp.type_id,
+ };
+ let expected_lookup_ty = self.lookup_type.lookup(expected_type_id)?;
+ let kind = ctx.type_arena[expected_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let expr = crate::Expression::Binary {
+ op,
+ left: if p1_lexp.type_id == expected_type_id {
+ left
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ right: if p2_lexp.type_id == expected_type_id {
+ right
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: right,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ /// A version of the binary op where one or both of the arguments might need to be casted to a
+ /// specific integer kind (unsigned or signed), used for operations like OpINotEqual or
+ /// OpUGreaterThan.
+ #[allow(clippy::too_many_arguments)]
+ fn parse_expr_int_comparison(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ kind: crate::ScalarKind,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p1_lookup_ty = self.lookup_type.lookup(p1_lexp.type_id)?;
+ let p1_kind = ctx.type_arena[p1_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+ let p2_lookup_ty = self.lookup_type.lookup(p2_lexp.type_id)?;
+ let p2_kind = ctx.type_arena[p2_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let expr = crate::Expression::Binary {
+ op,
+ left: if p1_kind == kind {
+ left
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ right: if p2_kind == kind {
+ right
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: right,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_shift_op(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let p2_handle = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+ // convert the shift to Uint
+ let right = ctx.expressions.append(
+ crate::Expression::As {
+ expr: p2_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ );
+
+ let expr = crate::Expression::Binary { op, left, right };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_derivative(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ (axis, ctrl): (crate::DerivativeAxis, crate::DerivativeControl),
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let arg_id = self.next()?;
+
+ let arg_lexp = self.lookup_expression.lookup(arg_id)?;
+ let arg_handle = self.get_expr_handle(arg_id, arg_lexp, ctx, emitter, block, body_idx);
+
+ let expr = crate::Expression::Derivative {
+ axis,
+ ctrl,
+ expr: arg_handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn insert_composite(
+ &self,
+ root_expr: Handle<crate::Expression>,
+ root_type_id: spirv::Word,
+ object_expr: Handle<crate::Expression>,
+ selections: &[spirv::Word],
+ type_arena: &UniqueArena<crate::Type>,
+ expressions: &mut Arena<crate::Expression>,
+ constants: &Arena<crate::Constant>,
+ span: crate::Span,
+ ) -> Result<Handle<crate::Expression>, Error> {
+ let selection = match selections.first() {
+ Some(&index) => index,
+ None => return Ok(object_expr),
+ };
+ let root_span = expressions.get_span(root_expr);
+ let root_lookup = self.lookup_type.lookup(root_type_id)?;
+ let (count, child_type_id) = match type_arena[root_lookup.handle].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let child_member = self
+ .lookup_member
+ .get(&(root_lookup.handle, selection))
+ .ok_or(Error::InvalidAccessType(root_type_id))?;
+ (members.len(), child_member.type_id)
+ }
+ crate::TypeInner::Array { size, .. } => {
+ let size = match size {
+ crate::ArraySize::Constant(handle) => match constants[handle] {
+ crate::Constant {
+ specialization: Some(_),
+ ..
+ } => return Err(Error::UnsupportedType(root_lookup.handle)),
+ ref unspecialized => unspecialized
+ .to_array_length()
+ .ok_or(Error::InvalidArraySize(handle))?,
+ },
+ // A runtime sized array is not a composite type
+ crate::ArraySize::Dynamic => {
+ return Err(Error::InvalidAccessType(root_type_id))
+ }
+ };
+
+ let child_type_id = root_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(root_type_id))?;
+
+ (size as usize, child_type_id)
+ }
+ crate::TypeInner::Vector { size, .. }
+ | crate::TypeInner::Matrix { columns: size, .. } => {
+ let child_type_id = root_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(root_type_id))?;
+ (size as usize, child_type_id)
+ }
+ _ => return Err(Error::InvalidAccessType(root_type_id)),
+ };
+
+ let mut components = Vec::with_capacity(count);
+ for index in 0..count as u32 {
+ let expr = expressions.append(
+ crate::Expression::AccessIndex {
+ base: root_expr,
+ index,
+ },
+ if index == selection { span } else { root_span },
+ );
+ components.push(expr);
+ }
+ components[selection as usize] = self.insert_composite(
+ components[selection as usize],
+ child_type_id,
+ object_expr,
+ &selections[1..],
+ type_arena,
+ expressions,
+ constants,
+ span,
+ )?;
+
+ Ok(expressions.append(
+ crate::Expression::Compose {
+ ty: root_lookup.handle,
+ components,
+ },
+ span,
+ ))
+ }
+
+ /// Add the next SPIR-V block's contents to `block_ctx`.
+ ///
+ /// Except for the function's entry block, `block_id` should be the label of
+ /// a block we've seen mentioned before, with an entry in
+ /// `block_ctx.body_for_label` to tell us which `Body` it contributes to.
+ fn next_block(&mut self, block_id: spirv::Word, ctx: &mut BlockContext) -> Result<(), Error> {
+ // Extend `body` with the correct form for a branch to `target`.
+ fn merger(body: &mut Body, target: &MergeBlockInformation) {
+ body.data.push(match *target {
+ MergeBlockInformation::LoopContinue => BodyFragment::Continue,
+ MergeBlockInformation::LoopMerge | MergeBlockInformation::SwitchMerge => {
+ BodyFragment::Break
+ }
+
+ // Finishing a selection merge means just falling off the end of
+ // the `accept` or `reject` block of the `If` statement.
+ MergeBlockInformation::SelectionMerge => return,
+ })
+ }
+
+ let mut emitter = super::Emitter::default();
+ emitter.start(ctx.expressions);
+
+ // Find the `Body` that this block belongs to. Index zero is the
+ // function's root `Body`, corresponding to `Function::body`.
+ let mut body_idx = *ctx.body_for_label.entry(block_id).or_default();
+ let mut block = crate::Block::new();
+ // Stores the merge block as defined by a `OpSelectionMerge` otherwise is `None`
+ //
+ // This is used in `OpSwitch` to promote the `MergeBlockInformation` from
+ // `SelectionMerge` to `SwitchMerge` to allow `Break`s this isn't desirable for
+ // `LoopMerge`s because otherwise `Continue`s wouldn't be allowed
+ let mut selection_merge_block = None;
+
+ macro_rules! get_expr_handle {
+ ($id:expr, $lexp:expr) => {
+ self.get_expr_handle($id, $lexp, ctx, &mut emitter, &mut block, body_idx)
+ };
+ }
+ macro_rules! parse_expr_op {
+ ($op:expr, BINARY) => {
+ self.parse_expr_binary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op)
+ };
+
+ ($op:expr, SHIFT) => {
+ self.parse_expr_shift_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op)
+ };
+ ($op:expr, UNARY) => {
+ self.parse_expr_unary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op)
+ };
+ ($axis:expr, $ctrl:expr, DERIVATIVE) => {
+ self.parse_expr_derivative(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ ($axis, $ctrl),
+ )
+ };
+ }
+
+ let terminator = loop {
+ use spirv::Op;
+ let start = self.data_offset;
+ let inst = self.next_inst()?;
+ let span = crate::Span::from(start..(start + 4 * (inst.wc as usize)));
+ log::debug!("\t\t{:?} [{}]", inst.op, inst.wc);
+
+ match inst.op {
+ Op::Line => {
+ inst.expect(4)?;
+ let _file_id = self.next()?;
+ let _row_id = self.next()?;
+ let _col_id = self.next()?;
+ }
+ Op::NoLine => inst.expect(1)?,
+ Op::Undef => {
+ inst.expect(3)?;
+ let (type_id, id, handle) =
+ self.parse_null_constant(inst, ctx.type_arena, ctx.const_arena)?;
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle: ctx
+ .expressions
+ .append(crate::Expression::Constant(handle), span),
+ type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Variable => {
+ inst.expect_at_least(4)?;
+ block.extend(emitter.finish(ctx.expressions));
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let _storage_class = self.next()?;
+ let init = if inst.wc > 4 {
+ inst.expect(5)?;
+ let init_id = self.next()?;
+ let lconst = self.lookup_constant.lookup(init_id)?;
+ Some(lconst.handle)
+ } else {
+ None
+ };
+
+ let name = self
+ .future_decor
+ .remove(&result_id)
+ .and_then(|decor| decor.name);
+ if let Some(ref name) = name {
+ log::debug!("\t\t\tid={} name={}", result_id, name);
+ }
+ let lookup_ty = self.lookup_type.lookup(result_type_id)?;
+ let var_handle = ctx.local_arena.append(
+ crate::LocalVariable {
+ name,
+ ty: match ctx.type_arena[lookup_ty.handle].inner {
+ crate::TypeInner::Pointer { base, .. } => base,
+ _ => lookup_ty.handle,
+ },
+ init,
+ },
+ span,
+ );
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx
+ .expressions
+ .append(crate::Expression::LocalVariable(var_handle), span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ emitter.start(ctx.expressions);
+ }
+ Op::Phi => {
+ inst.expect_at_least(3)?;
+ block.extend(emitter.finish(ctx.expressions));
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+
+ let name = format!("phi_{result_id}");
+ let local = ctx.local_arena.append(
+ crate::LocalVariable {
+ name: Some(name),
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ init: None,
+ },
+ self.span_from(start),
+ );
+ let pointer = ctx
+ .expressions
+ .append(crate::Expression::LocalVariable(local), span);
+
+ let in_count = (inst.wc - 3) / 2;
+ let mut phi = PhiExpression {
+ local,
+ expressions: Vec::with_capacity(in_count as usize),
+ };
+ for _ in 0..in_count {
+ let expr = self.next()?;
+ let block = self.next()?;
+ phi.expressions.push((expr, block));
+ }
+
+ ctx.phis.push(phi);
+ emitter.start(ctx.expressions);
+
+ // Associate the lookup with an actual value, which is emitted
+ // into the current block.
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx
+ .expressions
+ .append(crate::Expression::Load { pointer }, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::AccessChain | Op::InBoundsAccessChain => {
+ struct AccessExpression {
+ base_handle: Handle<crate::Expression>,
+ type_id: spirv::Word,
+ load_override: Option<LookupLoadOverride>,
+ }
+
+ inst.expect_at_least(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", base_id);
+
+ let mut acex = {
+ let lexp = self.lookup_expression.lookup(base_id)?;
+ let lty = self.lookup_type.lookup(lexp.type_id)?;
+
+ // HACK `OpAccessChain` and `OpInBoundsAccessChain`
+ // require for the result type to be a pointer, but if
+ // we're given a pointer to an image / sampler, it will
+ // be *already* dereferenced, since we do that early
+ // during `parse_type_pointer()`.
+ //
+ // This can happen only through `BindingArray`, since
+ // that's the only case where one can obtain a pointer
+ // to an image / sampler, and so let's match on that:
+ let dereference = match ctx.type_arena[lty.handle].inner {
+ crate::TypeInner::BindingArray { .. } => false,
+ _ => true,
+ };
+
+ let type_id = if dereference {
+ lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))?
+ } else {
+ lexp.type_id
+ };
+
+ AccessExpression {
+ base_handle: get_expr_handle!(base_id, lexp),
+ type_id,
+ load_override: self.lookup_load_override.get(&base_id).cloned(),
+ }
+ };
+
+ for _ in 4..inst.wc {
+ let access_id = self.next()?;
+ log::trace!("\t\t\tlooking up index expr {:?}", access_id);
+ let index_expr = self.lookup_expression.lookup(access_id)?.clone();
+ let index_expr_handle = get_expr_handle!(access_id, &index_expr);
+ let index_expr_data = &ctx.expressions[index_expr.handle];
+ let index_maybe = match *index_expr_data {
+ crate::Expression::Constant(const_handle) => {
+ Some(ctx.const_arena[const_handle].to_array_length().ok_or(
+ Error::InvalidAccess(crate::Expression::Constant(const_handle)),
+ )?)
+ }
+ _ => None,
+ };
+
+ log::trace!("\t\t\tlooking up type {:?}", acex.type_id);
+ let type_lookup = self.lookup_type.lookup(acex.type_id)?;
+ let ty = &ctx.type_arena[type_lookup.handle];
+ acex = match ty.inner {
+ // can only index a struct with a constant
+ crate::TypeInner::Struct { ref members, .. } => {
+ let index = index_maybe
+ .ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?;
+
+ let lookup_member = self
+ .lookup_member
+ .get(&(type_lookup.handle, index))
+ .ok_or(Error::InvalidAccessType(acex.type_id))?;
+ let base_handle = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: acex.base_handle,
+ index,
+ },
+ span,
+ );
+
+ if ty.name.as_deref() == Some("gl_PerVertex") {
+ if let Some(crate::Binding::BuiltIn(built_in)) =
+ members[index as usize].binding
+ {
+ self.gl_per_vertex_builtin_access.insert(built_in);
+ }
+ }
+
+ AccessExpression {
+ base_handle,
+ type_id: lookup_member.type_id,
+ load_override: if lookup_member.row_major {
+ debug_assert!(acex.load_override.is_none());
+ let sub_type_lookup =
+ self.lookup_type.lookup(lookup_member.type_id)?;
+ Some(match ctx.type_arena[sub_type_lookup.handle].inner {
+ // load it transposed, to match column major expectations
+ crate::TypeInner::Matrix { .. } => {
+ let loaded = ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: base_handle,
+ },
+ span,
+ );
+ let transposed = ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Transpose,
+ arg: loaded,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ );
+ LookupLoadOverride::Loaded(transposed)
+ }
+ _ => LookupLoadOverride::Pending,
+ })
+ } else {
+ None
+ },
+ }
+ }
+ crate::TypeInner::Matrix { .. } => {
+ let load_override = match acex.load_override {
+ // We are indexing inside a row-major matrix
+ Some(LookupLoadOverride::Loaded(load_expr)) => {
+ let index = index_maybe.ok_or_else(|| {
+ Error::InvalidAccess(index_expr_data.clone())
+ })?;
+ let sub_handle = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: load_expr,
+ index,
+ },
+ span,
+ );
+ Some(LookupLoadOverride::Loaded(sub_handle))
+ }
+ _ => None,
+ };
+ let sub_expr = match index_maybe {
+ Some(index) => crate::Expression::AccessIndex {
+ base: acex.base_handle,
+ index,
+ },
+ None => crate::Expression::Access {
+ base: acex.base_handle,
+ index: index_expr_handle,
+ },
+ };
+ AccessExpression {
+ base_handle: ctx.expressions.append(sub_expr, span),
+ type_id: type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(acex.type_id))?,
+ load_override,
+ }
+ }
+ // This must be a vector or an array.
+ _ => {
+ let base_handle = ctx.expressions.append(
+ crate::Expression::Access {
+ base: acex.base_handle,
+ index: index_expr_handle,
+ },
+ span,
+ );
+ let load_override = match acex.load_override {
+ // If there is a load override in place, then we always end up
+ // with a side-loaded value here.
+ Some(lookup_load_override) => {
+ let sub_expr = match lookup_load_override {
+ // We must be indexing into the array of row-major matrices.
+ // Let's load the result of indexing and transpose it.
+ LookupLoadOverride::Pending => {
+ let loaded = ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: base_handle,
+ },
+ span,
+ );
+ ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Transpose,
+ arg: loaded,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ )
+ }
+ // We are indexing inside a row-major matrix.
+ LookupLoadOverride::Loaded(load_expr) => {
+ ctx.expressions.append(
+ crate::Expression::Access {
+ base: load_expr,
+ index: index_expr_handle,
+ },
+ span,
+ )
+ }
+ };
+ Some(LookupLoadOverride::Loaded(sub_expr))
+ }
+ None => None,
+ };
+ AccessExpression {
+ base_handle,
+ type_id: type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(acex.type_id))?,
+ load_override,
+ }
+ }
+ };
+ }
+
+ if let Some(load_expr) = acex.load_override {
+ self.lookup_load_override.insert(result_id, load_expr);
+ }
+ let lookup_expression = LookupExpression {
+ handle: acex.base_handle,
+ type_id: result_type_id,
+ block_id,
+ };
+ self.lookup_expression.insert(result_id, lookup_expression);
+ }
+ Op::VectorExtractDynamic => {
+ inst.expect(5)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let composite_id = self.next()?;
+ let index_id = self.next()?;
+
+ let root_lexp = self.lookup_expression.lookup(composite_id)?;
+ let root_handle = get_expr_handle!(composite_id, root_lexp);
+ let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?;
+ let index_lexp = self.lookup_expression.lookup(index_id)?;
+ let index_handle = get_expr_handle!(index_id, index_lexp);
+
+ let num_components = match ctx.type_arena[root_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as usize,
+ _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)),
+ };
+
+ let mut handle = ctx.expressions.append(
+ crate::Expression::Access {
+ base: root_handle,
+ index: self.index_constant_expressions[0],
+ },
+ span,
+ );
+ for &index_expr in self.index_constant_expressions[1..num_components].iter() {
+ let access_expr = ctx.expressions.append(
+ crate::Expression::Access {
+ base: root_handle,
+ index: index_expr,
+ },
+ span,
+ );
+ let cond = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Equal,
+ left: index_expr,
+ right: index_handle,
+ },
+ span,
+ );
+ handle = ctx.expressions.append(
+ crate::Expression::Select {
+ condition: cond,
+ accept: access_expr,
+ reject: handle,
+ },
+ span,
+ );
+ }
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::VectorInsertDynamic => {
+ inst.expect(6)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let composite_id = self.next()?;
+ let object_id = self.next()?;
+ let index_id = self.next()?;
+
+ let object_lexp = self.lookup_expression.lookup(object_id)?;
+ let object_handle = get_expr_handle!(object_id, object_lexp);
+ let root_lexp = self.lookup_expression.lookup(composite_id)?;
+ let root_handle = get_expr_handle!(composite_id, root_lexp);
+ let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?;
+ let index_lexp = self.lookup_expression.lookup(index_id)?;
+ let mut index_handle = get_expr_handle!(index_id, index_lexp);
+ let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle;
+
+ // SPIR-V allows signed and unsigned indices but naga's is strict about
+ // types and since the `index_constants` are all signed integers, we need
+ // to cast the index to a signed integer if it's unsigned.
+ if let Some(crate::ScalarKind::Uint) =
+ ctx.type_arena[index_type].inner.scalar_kind()
+ {
+ index_handle = ctx.expressions.append(
+ crate::Expression::As {
+ expr: index_handle,
+ kind: crate::ScalarKind::Sint,
+ convert: None,
+ },
+ span,
+ )
+ }
+
+ let num_components = match ctx.type_arena[root_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as usize,
+ _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)),
+ };
+ let mut components = Vec::with_capacity(num_components);
+ for &index_expr in self.index_constant_expressions[..num_components].iter() {
+ let access_expr = ctx.expressions.append(
+ crate::Expression::Access {
+ base: root_handle,
+ index: index_expr,
+ },
+ span,
+ );
+ let cond = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Equal,
+ left: index_expr,
+ right: index_handle,
+ },
+ span,
+ );
+ let handle = ctx.expressions.append(
+ crate::Expression::Select {
+ condition: cond,
+ accept: object_handle,
+ reject: access_expr,
+ },
+ span,
+ );
+ components.push(handle);
+ }
+ let handle = ctx.expressions.append(
+ crate::Expression::Compose {
+ ty: root_type_lookup.handle,
+ components,
+ },
+ span,
+ );
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CompositeExtract => {
+ inst.expect_at_least(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", base_id);
+ let mut lexp = self.lookup_expression.lookup(base_id)?.clone();
+ lexp.handle = get_expr_handle!(base_id, &lexp);
+ for _ in 4..inst.wc {
+ let index = self.next()?;
+ log::trace!("\t\t\tlooking up type {:?}", lexp.type_id);
+ let type_lookup = self.lookup_type.lookup(lexp.type_id)?;
+ let type_id = match ctx.type_arena[type_lookup.handle].inner {
+ crate::TypeInner::Struct { .. } => {
+ self.lookup_member
+ .get(&(type_lookup.handle, index))
+ .ok_or(Error::InvalidAccessType(lexp.type_id))?
+ .type_id
+ }
+ crate::TypeInner::Array { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. } => type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(lexp.type_id))?,
+ ref other => {
+ log::warn!("composite type {:?}", other);
+ return Err(Error::UnsupportedType(type_lookup.handle));
+ }
+ };
+ lexp = LookupExpression {
+ handle: ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: lexp.handle,
+ index,
+ },
+ span,
+ ),
+ type_id,
+ block_id,
+ };
+ }
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: lexp.handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CompositeInsert => {
+ inst.expect_at_least(5)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let object_id = self.next()?;
+ let composite_id = self.next()?;
+ let mut selections = Vec::with_capacity(inst.wc as usize - 5);
+ for _ in 5..inst.wc {
+ selections.push(self.next()?);
+ }
+
+ let object_lexp = self.lookup_expression.lookup(object_id)?.clone();
+ let object_handle = get_expr_handle!(object_id, &object_lexp);
+ let root_lexp = self.lookup_expression.lookup(composite_id)?.clone();
+ let root_handle = get_expr_handle!(composite_id, &root_lexp);
+ let handle = self.insert_composite(
+ root_handle,
+ result_type_id,
+ object_handle,
+ &selections,
+ ctx.type_arena,
+ ctx.expressions,
+ ctx.const_arena,
+ span,
+ )?;
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CompositeConstruct => {
+ inst.expect_at_least(3)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let mut components = Vec::with_capacity(inst.wc as usize - 2);
+ for _ in 3..inst.wc {
+ let comp_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", comp_id);
+ let lexp = self.lookup_expression.lookup(comp_id)?;
+ let handle = get_expr_handle!(comp_id, lexp);
+ components.push(handle);
+ }
+ let ty = self.lookup_type.lookup(result_type_id)?.handle;
+ let first = components[0];
+ let expr = match ctx.type_arena[ty].inner {
+ // this is an optimization to detect the splat
+ crate::TypeInner::Vector { size, .. }
+ if components.len() == size as usize
+ && components[1..].iter().all(|&c| c == first) =>
+ {
+ crate::Expression::Splat { size, value: first }
+ }
+ _ => crate::Expression::Compose { ty, components },
+ };
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Load => {
+ inst.expect_at_least(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let pointer_id = self.next()?;
+ if inst.wc != 4 {
+ inst.expect(5)?;
+ let _memory_access = self.next()?;
+ }
+
+ let base_lexp = self.lookup_expression.lookup(pointer_id)?;
+ let base_handle = get_expr_handle!(pointer_id, base_lexp);
+ let type_lookup = self.lookup_type.lookup(base_lexp.type_id)?;
+ let handle = match ctx.type_arena[type_lookup.handle].inner {
+ crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => {
+ base_handle
+ }
+ _ => match self.lookup_load_override.get(&pointer_id) {
+ Some(&LookupLoadOverride::Loaded(handle)) => handle,
+ //Note: we aren't handling `LookupLoadOverride::Pending` properly here
+ _ => ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: base_handle,
+ },
+ span,
+ ),
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Store => {
+ inst.expect_at_least(3)?;
+
+ let pointer_id = self.next()?;
+ let value_id = self.next()?;
+ if inst.wc != 3 {
+ inst.expect(4)?;
+ let _memory_access = self.next()?;
+ }
+ let base_expr = self.lookup_expression.lookup(pointer_id)?;
+ let base_handle = get_expr_handle!(pointer_id, base_expr);
+ let value_expr = self.lookup_expression.lookup(value_id)?;
+ let value_handle = get_expr_handle!(value_id, value_expr);
+
+ block.extend(emitter.finish(ctx.expressions));
+ block.push(
+ crate::Statement::Store {
+ pointer: base_handle,
+ value: value_handle,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ // Arithmetic Instructions +, -, *, /, %
+ Op::SNegate | Op::FNegate => {
+ inst.expect(4)?;
+ self.parse_expr_unary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ crate::UnaryOperator::Negate,
+ )?;
+ }
+ Op::IAdd
+ | Op::ISub
+ | Op::IMul
+ | Op::BitwiseOr
+ | Op::BitwiseXor
+ | Op::BitwiseAnd
+ | Op::SDiv
+ | Op::SRem => {
+ inst.expect(5)?;
+ let operator = map_binary_operator(inst.op)?;
+ self.parse_expr_binary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ operator,
+ SignAnchor::Result,
+ )?;
+ }
+ Op::IEqual | Op::INotEqual => {
+ inst.expect(5)?;
+ let operator = map_binary_operator(inst.op)?;
+ self.parse_expr_binary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ operator,
+ SignAnchor::Operand,
+ )?;
+ }
+ Op::FAdd => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Add, BINARY)?;
+ }
+ Op::FSub => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Subtract, BINARY)?;
+ }
+ Op::FMul => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?;
+ }
+ Op::UDiv | Op::FDiv => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Divide, BINARY)?;
+ }
+ Op::UMod | Op::FRem => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Modulo, BINARY)?;
+ }
+ Op::SMod => {
+ inst.expect(5)?;
+
+ // x - y * int(floor(float(x) / float(y)))
+
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(
+ p1_id,
+ p1_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(
+ p2_id,
+ p2_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+
+ let result_ty = self.lookup_type.lookup(result_type_id)?;
+ let inner = &ctx.type_arena[result_ty.handle].inner;
+ let kind = inner.scalar_kind().unwrap();
+ let size = inner.size(ctx.const_arena) as u8;
+
+ let left_cast = ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind: crate::ScalarKind::Float,
+ convert: Some(size),
+ },
+ span,
+ );
+ let right_cast = ctx.expressions.append(
+ crate::Expression::As {
+ expr: right,
+ kind: crate::ScalarKind::Float,
+ convert: Some(size),
+ },
+ span,
+ );
+ let div = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: left_cast,
+ right: right_cast,
+ },
+ span,
+ );
+ let floor = ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Floor,
+ arg: div,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ );
+ let cast = ctx.expressions.append(
+ crate::Expression::As {
+ expr: floor,
+ kind,
+ convert: Some(size),
+ },
+ span,
+ );
+ let mult = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left: cast,
+ right,
+ },
+ span,
+ );
+ let sub = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Subtract,
+ left,
+ right: mult,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: sub,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::FMod => {
+ inst.expect(5)?;
+
+ // x - y * floor(x / y)
+
+ let start = self.data_offset;
+ let span = self.span_from_with_op(start);
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(
+ p1_id,
+ p1_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(
+ p2_id,
+ p2_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+
+ let div = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left,
+ right,
+ },
+ span,
+ );
+ let floor = ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Floor,
+ arg: div,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ );
+ let mult = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left: floor,
+ right,
+ },
+ span,
+ );
+ let sub = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Subtract,
+ left,
+ right: mult,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: sub,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::VectorTimesScalar
+ | Op::VectorTimesMatrix
+ | Op::MatrixTimesScalar
+ | Op::MatrixTimesVector
+ | Op::MatrixTimesMatrix => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?;
+ }
+ Op::Transpose => {
+ inst.expect(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let matrix_id = self.next()?;
+ let matrix_lexp = self.lookup_expression.lookup(matrix_id)?;
+ let matrix_handle = get_expr_handle!(matrix_id, matrix_lexp);
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::Transpose,
+ arg: matrix_handle,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Dot => {
+ inst.expect(5)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let left_id = self.next()?;
+ let right_id = self.next()?;
+ let left_lexp = self.lookup_expression.lookup(left_id)?;
+ let left_handle = get_expr_handle!(left_id, left_lexp);
+ let right_lexp = self.lookup_expression.lookup(right_id)?;
+ let right_handle = get_expr_handle!(right_id, right_lexp);
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::Dot,
+ arg: left_handle,
+ arg1: Some(right_handle),
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::BitFieldInsert => {
+ inst.expect(7)?;
+
+ let start = self.data_offset;
+ let span = self.span_from_with_op(start);
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ let insert_id = self.next()?;
+ let offset_id = self.next()?;
+ let count_id = self.next()?;
+ let base_lexp = self.lookup_expression.lookup(base_id)?;
+ let base_handle = get_expr_handle!(base_id, base_lexp);
+ let insert_lexp = self.lookup_expression.lookup(insert_id)?;
+ let insert_handle = get_expr_handle!(insert_id, insert_lexp);
+ let offset_lexp = self.lookup_expression.lookup(offset_id)?;
+ let offset_handle = get_expr_handle!(offset_id, offset_lexp);
+ let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?;
+ let count_lexp = self.lookup_expression.lookup(count_id)?;
+ let count_handle = get_expr_handle!(count_id, count_lexp);
+ let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?;
+
+ let offset_kind = ctx.type_arena[offset_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+ let count_kind = ctx.type_arena[count_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: offset_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ offset_handle
+ };
+
+ let count_cast_handle = if count_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: count_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ count_handle
+ };
+
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::InsertBits,
+ arg: base_handle,
+ arg1: Some(insert_handle),
+ arg2: Some(offset_cast_handle),
+ arg3: Some(count_cast_handle),
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::BitFieldSExtract | Op::BitFieldUExtract => {
+ inst.expect(6)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ let offset_id = self.next()?;
+ let count_id = self.next()?;
+ let base_lexp = self.lookup_expression.lookup(base_id)?;
+ let base_handle = get_expr_handle!(base_id, base_lexp);
+ let offset_lexp = self.lookup_expression.lookup(offset_id)?;
+ let offset_handle = get_expr_handle!(offset_id, offset_lexp);
+ let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?;
+ let count_lexp = self.lookup_expression.lookup(count_id)?;
+ let count_handle = get_expr_handle!(count_id, count_lexp);
+ let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?;
+
+ let offset_kind = ctx.type_arena[offset_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+ let count_kind = ctx.type_arena[count_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: offset_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ offset_handle
+ };
+
+ let count_cast_handle = if count_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: count_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ count_handle
+ };
+
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::ExtractBits,
+ arg: base_handle,
+ arg1: Some(offset_cast_handle),
+ arg2: Some(count_cast_handle),
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::BitReverse | Op::BitCount => {
+ inst.expect(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ let base_lexp = self.lookup_expression.lookup(base_id)?;
+ let base_handle = get_expr_handle!(base_id, base_lexp);
+ let expr = crate::Expression::Math {
+ fun: match inst.op {
+ Op::BitReverse => crate::MathFunction::ReverseBits,
+ Op::BitCount => crate::MathFunction::CountOneBits,
+ _ => unreachable!(),
+ },
+ arg: base_handle,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::OuterProduct => {
+ inst.expect(5)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let left_id = self.next()?;
+ let right_id = self.next()?;
+ let left_lexp = self.lookup_expression.lookup(left_id)?;
+ let left_handle = get_expr_handle!(left_id, left_lexp);
+ let right_lexp = self.lookup_expression.lookup(right_id)?;
+ let right_handle = get_expr_handle!(right_id, right_lexp);
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::Outer,
+ arg: left_handle,
+ arg1: Some(right_handle),
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ // Bitwise instructions
+ Op::Not => {
+ inst.expect(4)?;
+ self.parse_expr_unary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ crate::UnaryOperator::Not,
+ )?;
+ }
+ Op::ShiftRightLogical => {
+ inst.expect(5)?;
+ //TODO: convert input and result to usigned
+ parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?;
+ }
+ Op::ShiftRightArithmetic => {
+ inst.expect(5)?;
+ //TODO: convert input and result to signed
+ parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?;
+ }
+ Op::ShiftLeftLogical => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::ShiftLeft, SHIFT)?;
+ }
+ // Sampling
+ Op::Image => {
+ inst.expect(4)?;
+ self.parse_image_uncouple(block_id)?;
+ }
+ Op::SampledImage => {
+ inst.expect(5)?;
+ self.parse_image_couple()?;
+ }
+ Op::ImageWrite => {
+ let extra = inst.expect_at_least(4)?;
+ let stmt =
+ self.parse_image_write(extra, ctx, &mut emitter, &mut block, body_idx)?;
+ block.extend(emitter.finish(ctx.expressions));
+ block.push(stmt, span);
+ emitter.start(ctx.expressions);
+ }
+ Op::ImageFetch | Op::ImageRead => {
+ let extra = inst.expect_at_least(5)?;
+ self.parse_image_load(
+ extra,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleImplicitLod | Op::ImageSampleExplicitLod => {
+ let extra = inst.expect_at_least(5)?;
+ let options = image::SamplingOptions {
+ compare: false,
+ project: false,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleProjImplicitLod | Op::ImageSampleProjExplicitLod => {
+ let extra = inst.expect_at_least(5)?;
+ let options = image::SamplingOptions {
+ compare: false,
+ project: true,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleDrefImplicitLod | Op::ImageSampleDrefExplicitLod => {
+ let extra = inst.expect_at_least(6)?;
+ let options = image::SamplingOptions {
+ compare: true,
+ project: false,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleProjDrefImplicitLod | Op::ImageSampleProjDrefExplicitLod => {
+ let extra = inst.expect_at_least(6)?;
+ let options = image::SamplingOptions {
+ compare: true,
+ project: true,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageQuerySize => {
+ inst.expect(4)?;
+ self.parse_image_query_size(
+ false,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageQuerySizeLod => {
+ inst.expect(5)?;
+ self.parse_image_query_size(
+ true,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageQueryLevels => {
+ inst.expect(4)?;
+ self.parse_image_query_other(
+ crate::ImageQuery::NumLevels,
+ ctx.expressions,
+ block_id,
+ )?;
+ }
+ Op::ImageQuerySamples => {
+ inst.expect(4)?;
+ self.parse_image_query_other(
+ crate::ImageQuery::NumSamples,
+ ctx.expressions,
+ block_id,
+ )?;
+ }
+ // other ops
+ Op::Select => {
+ inst.expect(6)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let condition = self.next()?;
+ let o1_id = self.next()?;
+ let o2_id = self.next()?;
+
+ let cond_lexp = self.lookup_expression.lookup(condition)?;
+ let cond_handle = get_expr_handle!(condition, cond_lexp);
+ let o1_lexp = self.lookup_expression.lookup(o1_id)?;
+ let o1_handle = get_expr_handle!(o1_id, o1_lexp);
+ let o2_lexp = self.lookup_expression.lookup(o2_id)?;
+ let o2_handle = get_expr_handle!(o2_id, o2_lexp);
+
+ let expr = crate::Expression::Select {
+ condition: cond_handle,
+ accept: o1_handle,
+ reject: o2_handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::VectorShuffle => {
+ inst.expect_at_least(5)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let v1_id = self.next()?;
+ let v2_id = self.next()?;
+
+ let v1_lexp = self.lookup_expression.lookup(v1_id)?;
+ let v1_lty = self.lookup_type.lookup(v1_lexp.type_id)?;
+ let v1_handle = get_expr_handle!(v1_id, v1_lexp);
+ let n1 = match ctx.type_arena[v1_lty.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u32,
+ _ => return Err(Error::InvalidInnerType(v1_lexp.type_id)),
+ };
+ let v2_lexp = self.lookup_expression.lookup(v2_id)?;
+ let v2_lty = self.lookup_type.lookup(v2_lexp.type_id)?;
+ let v2_handle = get_expr_handle!(v2_id, v2_lexp);
+ let n2 = match ctx.type_arena[v2_lty.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u32,
+ _ => return Err(Error::InvalidInnerType(v2_lexp.type_id)),
+ };
+
+ self.temp_bytes.clear();
+ let mut max_component = 0;
+ for _ in 5..inst.wc as usize {
+ let mut index = self.next()?;
+ if index == u32::MAX {
+ // treat Undefined as X
+ index = 0;
+ }
+ max_component = max_component.max(index);
+ self.temp_bytes.push(index as u8);
+ }
+
+ // Check for swizzle first.
+ let expr = if max_component < n1 {
+ use crate::SwizzleComponent as Sc;
+ let size = match self.temp_bytes.len() {
+ 2 => crate::VectorSize::Bi,
+ 3 => crate::VectorSize::Tri,
+ _ => crate::VectorSize::Quad,
+ };
+ let mut pattern = [Sc::X; 4];
+ for (pat, index) in pattern.iter_mut().zip(self.temp_bytes.drain(..)) {
+ *pat = match index {
+ 0 => Sc::X,
+ 1 => Sc::Y,
+ 2 => Sc::Z,
+ _ => Sc::W,
+ };
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector: v1_handle,
+ pattern,
+ }
+ } else {
+ // Fall back to access + compose
+ let mut components = Vec::with_capacity(self.temp_bytes.len());
+ for index in self.temp_bytes.drain(..).map(|i| i as u32) {
+ let expr = if index < n1 {
+ crate::Expression::AccessIndex {
+ base: v1_handle,
+ index,
+ }
+ } else if index < n1 + n2 {
+ crate::Expression::AccessIndex {
+ base: v2_handle,
+ index: index - n1,
+ }
+ } else {
+ return Err(Error::InvalidAccessIndex(index));
+ };
+ components.push(ctx.expressions.append(expr, span));
+ }
+ crate::Expression::Compose {
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ components,
+ }
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Bitcast
+ | Op::ConvertSToF
+ | Op::ConvertUToF
+ | Op::ConvertFToU
+ | Op::ConvertFToS
+ | Op::FConvert
+ | Op::UConvert
+ | Op::SConvert => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let value_id = self.next()?;
+
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let ty_lookup = self.lookup_type.lookup(result_type_id)?;
+ let (kind, width) = match ctx.type_arena[ty_lookup.handle].inner {
+ crate::TypeInner::Scalar { kind, width }
+ | crate::TypeInner::Vector { kind, width, .. } => (kind, width),
+ crate::TypeInner::Matrix { width, .. } => (crate::ScalarKind::Float, width),
+ _ => return Err(Error::InvalidAsType(ty_lookup.handle)),
+ };
+
+ let expr = crate::Expression::As {
+ expr: get_expr_handle!(value_id, value_lexp),
+ kind,
+ convert: if kind == crate::ScalarKind::Bool {
+ Some(crate::BOOL_WIDTH)
+ } else if inst.op == Op::Bitcast {
+ None
+ } else {
+ Some(width)
+ },
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::FunctionCall => {
+ inst.expect_at_least(4)?;
+ block.extend(emitter.finish(ctx.expressions));
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let func_id = self.next()?;
+
+ let mut arguments = Vec::with_capacity(inst.wc as usize - 4);
+ for _ in 0..arguments.capacity() {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ arguments.push(get_expr_handle!(arg_id, lexp));
+ }
+
+ // We just need an unique handle here, nothing more.
+ let function = self.add_call(ctx.function_id, func_id);
+
+ let result = if self.lookup_void_type == Some(result_type_id) {
+ None
+ } else {
+ let expr_handle = ctx
+ .expressions
+ .append(crate::Expression::CallResult(function), span);
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expr_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Some(expr_handle)
+ };
+ block.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ Op::ExtInst => {
+ use crate::MathFunction as Mf;
+ use spirv::GLOp as Glo;
+
+ let base_wc = 5;
+ inst.expect_at_least(base_wc)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let set_id = self.next()?;
+ if Some(set_id) != self.ext_glsl_id {
+ return Err(Error::UnsupportedExtInstSet(set_id));
+ }
+ let inst_id = self.next()?;
+ let gl_op = Glo::from_u32(inst_id).ok_or(Error::UnsupportedExtInst(inst_id))?;
+
+ let fun = match gl_op {
+ Glo::Round => Mf::Round,
+ Glo::RoundEven => Mf::Round,
+ Glo::Trunc => Mf::Trunc,
+ Glo::FAbs | Glo::SAbs => Mf::Abs,
+ Glo::FSign | Glo::SSign => Mf::Sign,
+ Glo::Floor => Mf::Floor,
+ Glo::Ceil => Mf::Ceil,
+ Glo::Fract => Mf::Fract,
+ Glo::Sin => Mf::Sin,
+ Glo::Cos => Mf::Cos,
+ Glo::Tan => Mf::Tan,
+ Glo::Asin => Mf::Asin,
+ Glo::Acos => Mf::Acos,
+ Glo::Atan => Mf::Atan,
+ Glo::Sinh => Mf::Sinh,
+ Glo::Cosh => Mf::Cosh,
+ Glo::Tanh => Mf::Tanh,
+ Glo::Atan2 => Mf::Atan2,
+ Glo::Asinh => Mf::Asinh,
+ Glo::Acosh => Mf::Acosh,
+ Glo::Atanh => Mf::Atanh,
+ Glo::Radians => Mf::Radians,
+ Glo::Degrees => Mf::Degrees,
+ Glo::Pow => Mf::Pow,
+ Glo::Exp => Mf::Exp,
+ Glo::Log => Mf::Log,
+ Glo::Exp2 => Mf::Exp2,
+ Glo::Log2 => Mf::Log2,
+ Glo::Sqrt => Mf::Sqrt,
+ Glo::InverseSqrt => Mf::InverseSqrt,
+ Glo::MatrixInverse => Mf::Inverse,
+ Glo::Determinant => Mf::Determinant,
+ Glo::Modf => Mf::Modf,
+ Glo::FMin | Glo::UMin | Glo::SMin | Glo::NMin => Mf::Min,
+ Glo::FMax | Glo::UMax | Glo::SMax | Glo::NMax => Mf::Max,
+ Glo::FClamp | Glo::UClamp | Glo::SClamp | Glo::NClamp => Mf::Clamp,
+ Glo::FMix => Mf::Mix,
+ Glo::Step => Mf::Step,
+ Glo::SmoothStep => Mf::SmoothStep,
+ Glo::Fma => Mf::Fma,
+ Glo::Frexp => Mf::Frexp, //TODO: FrexpStruct?
+ Glo::Ldexp => Mf::Ldexp,
+ Glo::Length => Mf::Length,
+ Glo::Distance => Mf::Distance,
+ Glo::Cross => Mf::Cross,
+ Glo::Normalize => Mf::Normalize,
+ Glo::FaceForward => Mf::FaceForward,
+ Glo::Reflect => Mf::Reflect,
+ Glo::Refract => Mf::Refract,
+ Glo::PackUnorm4x8 => Mf::Pack4x8unorm,
+ Glo::PackSnorm4x8 => Mf::Pack4x8snorm,
+ Glo::PackHalf2x16 => Mf::Pack2x16float,
+ Glo::PackUnorm2x16 => Mf::Pack2x16unorm,
+ Glo::PackSnorm2x16 => Mf::Pack2x16snorm,
+ Glo::UnpackUnorm4x8 => Mf::Unpack4x8unorm,
+ Glo::UnpackSnorm4x8 => Mf::Unpack4x8snorm,
+ Glo::UnpackHalf2x16 => Mf::Unpack2x16float,
+ Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm,
+ Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm,
+ Glo::FindILsb => Mf::FindLsb,
+ Glo::FindUMsb | Glo::FindSMsb => Mf::FindMsb,
+ _ => return Err(Error::UnsupportedExtInst(inst_id)),
+ };
+
+ let arg_count = fun.argument_count();
+ inst.expect(base_wc + arg_count as u16)?;
+ let arg = {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ get_expr_handle!(arg_id, lexp)
+ };
+ let arg1 = if arg_count > 1 {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ Some(get_expr_handle!(arg_id, lexp))
+ } else {
+ None
+ };
+ let arg2 = if arg_count > 2 {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ Some(get_expr_handle!(arg_id, lexp))
+ } else {
+ None
+ };
+ let arg3 = if arg_count > 3 {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ Some(get_expr_handle!(arg_id, lexp))
+ } else {
+ None
+ };
+
+ let expr = crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ // Relational and Logical Instructions
+ Op::LogicalNot => {
+ inst.expect(4)?;
+ parse_expr_op!(crate::UnaryOperator::Not, UNARY)?;
+ }
+ Op::LogicalOr => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::LogicalOr, BINARY)?;
+ }
+ Op::LogicalAnd => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::LogicalAnd, BINARY)?;
+ }
+ Op::SGreaterThan | Op::SGreaterThanEqual | Op::SLessThan | Op::SLessThanEqual => {
+ inst.expect(5)?;
+ self.parse_expr_int_comparison(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ map_binary_operator(inst.op)?,
+ crate::ScalarKind::Sint,
+ )?;
+ }
+ Op::UGreaterThan | Op::UGreaterThanEqual | Op::ULessThan | Op::ULessThanEqual => {
+ inst.expect(5)?;
+ self.parse_expr_int_comparison(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ map_binary_operator(inst.op)?,
+ crate::ScalarKind::Uint,
+ )?;
+ }
+ Op::FOrdEqual
+ | Op::FUnordEqual
+ | Op::FOrdNotEqual
+ | Op::FUnordNotEqual
+ | Op::FOrdLessThan
+ | Op::FUnordLessThan
+ | Op::FOrdGreaterThan
+ | Op::FUnordGreaterThan
+ | Op::FOrdLessThanEqual
+ | Op::FUnordLessThanEqual
+ | Op::FOrdGreaterThanEqual
+ | Op::FUnordGreaterThanEqual
+ | Op::LogicalEqual
+ | Op::LogicalNotEqual => {
+ inst.expect(5)?;
+ let operator = map_binary_operator(inst.op)?;
+ parse_expr_op!(operator, BINARY)?;
+ }
+ Op::Any | Op::All | Op::IsNan | Op::IsInf | Op::IsFinite | Op::IsNormal => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let arg_id = self.next()?;
+
+ let arg_lexp = self.lookup_expression.lookup(arg_id)?;
+ let arg_handle = get_expr_handle!(arg_id, arg_lexp);
+
+ let expr = crate::Expression::Relational {
+ fun: map_relational_fun(inst.op)?,
+ argument: arg_handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Kill => {
+ inst.expect(1)?;
+ break Some(crate::Statement::Kill);
+ }
+ Op::Unreachable => {
+ inst.expect(1)?;
+ break None;
+ }
+ Op::Return => {
+ inst.expect(1)?;
+ break Some(crate::Statement::Return { value: None });
+ }
+ Op::ReturnValue => {
+ inst.expect(2)?;
+ let value_id = self.next()?;
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let value_handle = get_expr_handle!(value_id, value_lexp);
+ break Some(crate::Statement::Return {
+ value: Some(value_handle),
+ });
+ }
+ Op::Branch => {
+ inst.expect(2)?;
+ let target_id = self.next()?;
+
+ // If this is a branch to a merge or continue block,
+ // then that ends the current body.
+ if let Some(info) = ctx.mergers.get(&target_id) {
+ block.extend(emitter.finish(ctx.expressions));
+ ctx.blocks.insert(block_id, block);
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::BlockId(block_id));
+
+ merger(body, info);
+
+ return Ok(());
+ }
+
+ // Since the target of the branch has no merge information,
+ // this must be the only branch to that block. This means
+ // we can treat it as an extension of the current `Body`.
+ //
+ // NOTE: it's possible that another branch was already made to this block
+ // setting the body index in which case it SHOULD NOT be overridden.
+ // For example a switch with falltrough, the OpSwitch will set the body to
+ // the respective case and the case may branch to another case in which case
+ // the body index shouldn't be changed
+ ctx.body_for_label.entry(target_id).or_insert(body_idx);
+
+ break None;
+ }
+ Op::BranchConditional => {
+ inst.expect_at_least(4)?;
+
+ let condition = {
+ let condition_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(condition_id)?;
+ get_expr_handle!(condition_id, lexp)
+ };
+
+ block.extend(emitter.finish(ctx.expressions));
+ ctx.blocks.insert(block_id, block);
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::BlockId(block_id));
+
+ let true_id = self.next()?;
+ let false_id = self.next()?;
+
+ let same_target = true_id == false_id;
+
+ // Start a body block for the `accept` branch.
+ let accept = ctx.bodies.len();
+ let mut accept_block = Body::with_parent(body_idx);
+
+ // If the `OpBranchConditional`target is somebody else's
+ // merge or continue block, then put a `Break` or `Continue`
+ // statement in this new body block.
+ if let Some(info) = ctx.mergers.get(&true_id) {
+ merger(
+ match same_target {
+ true => &mut ctx.bodies[body_idx],
+ false => &mut accept_block,
+ },
+ info,
+ )
+ } else {
+ // Note the body index for the block we're branching to.
+ let prev = ctx.body_for_label.insert(
+ true_id,
+ match same_target {
+ true => body_idx,
+ false => accept,
+ },
+ );
+ debug_assert!(prev.is_none());
+ }
+
+ if same_target {
+ return Ok(());
+ }
+
+ ctx.bodies.push(accept_block);
+
+ // Handle the `reject` branch just like the `accept` block.
+ let reject = ctx.bodies.len();
+ let mut reject_block = Body::with_parent(body_idx);
+
+ if let Some(info) = ctx.mergers.get(&false_id) {
+ merger(&mut reject_block, info)
+ } else {
+ let prev = ctx.body_for_label.insert(false_id, reject);
+ debug_assert!(prev.is_none());
+ }
+
+ ctx.bodies.push(reject_block);
+
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::If {
+ condition,
+ accept,
+ reject,
+ });
+
+ // Consume branch weights
+ for _ in 4..inst.wc {
+ let _ = self.next()?;
+ }
+
+ return Ok(());
+ }
+ Op::Switch => {
+ inst.expect_at_least(3)?;
+ let selector = self.next()?;
+ let default_id = self.next()?;
+
+ // If the previous instruction was a `OpSelectionMerge` then we must
+ // promote the `MergeBlockInformation` to a `SwitchMerge`
+ if let Some(merge) = selection_merge_block {
+ ctx.mergers
+ .insert(merge, MergeBlockInformation::SwitchMerge);
+ }
+
+ let default = ctx.bodies.len();
+ ctx.bodies.push(Body::with_parent(body_idx));
+ ctx.body_for_label.entry(default_id).or_insert(default);
+
+ let selector_lexp = &self.lookup_expression[&selector];
+ let selector_lty = self.lookup_type.lookup(selector_lexp.type_id)?;
+ let selector_handle = get_expr_handle!(selector, selector_lexp);
+ let selector = match ctx.type_arena[selector_lty.handle].inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ } => {
+ // IR expects a signed integer, so do a bitcast
+ ctx.expressions.append(
+ crate::Expression::As {
+ kind: crate::ScalarKind::Sint,
+ expr: selector_handle,
+ convert: None,
+ },
+ span,
+ )
+ }
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ } => selector_handle,
+ ref other => unimplemented!("Unexpected selector {:?}", other),
+ };
+
+ // Clear past switch cases to prevent them from entering this one
+ self.switch_cases.clear();
+
+ for _ in 0..(inst.wc - 3) / 2 {
+ let literal = self.next()?;
+ let target = self.next()?;
+
+ let case_body_idx = ctx.bodies.len();
+
+ // Check if any previous case already used this target block id, if so
+ // group them together to reorder them later so that no weird
+ // falltrough cases happen.
+ if let Some(&mut (_, ref mut literals)) = self.switch_cases.get_mut(&target)
+ {
+ literals.push(literal as i32);
+ continue;
+ }
+
+ let mut body = Body::with_parent(body_idx);
+
+ if let Some(info) = ctx.mergers.get(&target) {
+ merger(&mut body, info);
+ }
+
+ ctx.bodies.push(body);
+ ctx.body_for_label.entry(target).or_insert(case_body_idx);
+
+ // Register this target block id as already having been processed and
+ // the respective body index assigned and the first case value
+ self.switch_cases
+ .insert(target, (case_body_idx, vec![literal as i32]));
+ }
+
+ // Loop trough the collected target blocks creating a new case for each
+ // literal pointing to it, only one case will have the true body and all the
+ // others will be empty falltrough so that they all execute the same body
+ // without duplicating code.
+ //
+ // Since `switch_cases` is an indexmap the order of insertion is preserved
+ // this is needed because spir-v defines falltrough order in the switch
+ // instruction.
+ let mut cases = Vec::with_capacity((inst.wc as usize - 3) / 2);
+ for &(case_body_idx, ref literals) in self.switch_cases.values() {
+ let value = literals[0];
+
+ for &literal in literals.iter().skip(1) {
+ let empty_body_idx = ctx.bodies.len();
+ let body = Body::with_parent(body_idx);
+
+ ctx.bodies.push(body);
+
+ cases.push((literal, empty_body_idx));
+ }
+
+ cases.push((value, case_body_idx));
+ }
+
+ block.extend(emitter.finish(ctx.expressions));
+
+ let body = &mut ctx.bodies[body_idx];
+ ctx.blocks.insert(block_id, block);
+ // Make sure the vector has space for at least two more allocations
+ body.data.reserve(2);
+ body.data.push(BodyFragment::BlockId(block_id));
+ body.data.push(BodyFragment::Switch {
+ selector,
+ cases,
+ default,
+ });
+
+ return Ok(());
+ }
+ Op::SelectionMerge => {
+ inst.expect(3)?;
+ let merge_block_id = self.next()?;
+ // TODO: Selection Control Mask
+ let _selection_control = self.next()?;
+
+ // Indicate that the merge block is a continuation of the
+ // current `Body`.
+ ctx.body_for_label.entry(merge_block_id).or_insert(body_idx);
+
+ // Let subsequent branches to the merge block know that
+ // they've reached the end of the selection construct.
+ ctx.mergers
+ .insert(merge_block_id, MergeBlockInformation::SelectionMerge);
+
+ selection_merge_block = Some(merge_block_id);
+ }
+ Op::LoopMerge => {
+ inst.expect_at_least(4)?;
+ let merge_block_id = self.next()?;
+ let continuing = self.next()?;
+
+ // TODO: Loop Control Parameters
+ for _ in 0..inst.wc - 3 {
+ self.next()?;
+ }
+
+ // Indicate that the merge block is a continuation of the
+ // current `Body`.
+ ctx.body_for_label.entry(merge_block_id).or_insert(body_idx);
+ // Let subsequent branches to the merge block know that
+ // they're `Break` statements.
+ ctx.mergers
+ .insert(merge_block_id, MergeBlockInformation::LoopMerge);
+
+ let loop_body_idx = ctx.bodies.len();
+ ctx.bodies.push(Body::with_parent(body_idx));
+
+ let continue_idx = ctx.bodies.len();
+ // The continue block inherits the scope of the loop body
+ ctx.bodies.push(Body::with_parent(loop_body_idx));
+ ctx.body_for_label.entry(continuing).or_insert(continue_idx);
+ // Let subsequent branches to the continue block know that
+ // they're `Continue` statements.
+ ctx.mergers
+ .insert(continuing, MergeBlockInformation::LoopContinue);
+
+ // The loop header always belongs to the loop body
+ ctx.body_for_label.insert(block_id, loop_body_idx);
+
+ let parent_body = &mut ctx.bodies[body_idx];
+ parent_body.data.push(BodyFragment::Loop {
+ body: loop_body_idx,
+ continuing: continue_idx,
+ });
+ body_idx = loop_body_idx;
+ }
+ Op::DPdxCoarse => {
+ parse_expr_op!(
+ crate::DerivativeAxis::X,
+ crate::DerivativeControl::Coarse,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdyCoarse => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Y,
+ crate::DerivativeControl::Coarse,
+ DERIVATIVE
+ )?;
+ }
+ Op::FwidthCoarse => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Width,
+ crate::DerivativeControl::Coarse,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdxFine => {
+ parse_expr_op!(
+ crate::DerivativeAxis::X,
+ crate::DerivativeControl::Fine,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdyFine => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Y,
+ crate::DerivativeControl::Fine,
+ DERIVATIVE
+ )?;
+ }
+ Op::FwidthFine => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Width,
+ crate::DerivativeControl::Fine,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdx => {
+ parse_expr_op!(
+ crate::DerivativeAxis::X,
+ crate::DerivativeControl::None,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdy => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Y,
+ crate::DerivativeControl::None,
+ DERIVATIVE
+ )?;
+ }
+ Op::Fwidth => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Width,
+ crate::DerivativeControl::None,
+ DERIVATIVE
+ )?;
+ }
+ Op::ArrayLength => {
+ inst.expect(5)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let structure_id = self.next()?;
+ let member_index = self.next()?;
+
+ // We're assuming that the validation pass, if it's run, will catch if the
+ // wrong types or parameters are supplied here.
+
+ let structure_ptr = self.lookup_expression.lookup(structure_id)?;
+ let structure_handle = get_expr_handle!(structure_id, structure_ptr);
+
+ let member_ptr = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: structure_handle,
+ index: member_index,
+ },
+ span,
+ );
+
+ let length = ctx
+ .expressions
+ .append(crate::Expression::ArrayLength(member_ptr), span);
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: length,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CopyMemory => {
+ inst.expect_at_least(3)?;
+ let target_id = self.next()?;
+ let source_id = self.next()?;
+ let _memory_access = if inst.wc != 3 {
+ inst.expect(4)?;
+ spirv::MemoryAccess::from_bits(self.next()?)
+ .ok_or(Error::InvalidParameter(Op::CopyMemory))?
+ } else {
+ spirv::MemoryAccess::NONE
+ };
+
+ // TODO: check if the source and target types are the same?
+ let target = self.lookup_expression.lookup(target_id)?;
+ let target_handle = get_expr_handle!(target_id, target);
+ let source = self.lookup_expression.lookup(source_id)?;
+ let source_handle = get_expr_handle!(source_id, source);
+
+ // This operation is practically the same as loading and then storing, I think.
+ let value_expr = ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: source_handle,
+ },
+ span,
+ );
+
+ block.extend(emitter.finish(ctx.expressions));
+ block.push(
+ crate::Statement::Store {
+ pointer: target_handle,
+ value: value_expr,
+ },
+ span,
+ );
+
+ emitter.start(ctx.expressions);
+ }
+ Op::ControlBarrier => {
+ inst.expect(4)?;
+ let exec_scope_id = self.next()?;
+ let _mem_scope_raw = self.next()?;
+ let semantics_id = self.next()?;
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let semantics_const = self.lookup_constant.lookup(semantics_id)?;
+ let exec_scope = match ctx.const_arena[exec_scope_const.handle].inner {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(raw),
+ width: _,
+ } => raw as u32,
+ _ => return Err(Error::InvalidBarrierScope(exec_scope_id)),
+ };
+ let semantics = match ctx.const_arena[semantics_const.handle].inner {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(raw),
+ width: _,
+ } => raw as u32,
+ _ => return Err(Error::InvalidBarrierMemorySemantics(semantics_id)),
+ };
+ if exec_scope == spirv::Scope::Workgroup as u32 {
+ let mut flags = crate::Barrier::empty();
+ flags.set(
+ crate::Barrier::STORAGE,
+ semantics & spirv::MemorySemantics::UNIFORM_MEMORY.bits() != 0,
+ );
+ flags.set(
+ crate::Barrier::WORK_GROUP,
+ semantics
+ & (spirv::MemorySemantics::SUBGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY)
+ .bits()
+ != 0,
+ );
+ block.push(crate::Statement::Barrier(flags), span);
+ } else {
+ log::warn!("Unsupported barrier execution scope: {}", exec_scope);
+ }
+ }
+ Op::CopyObject => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let operand_id = self.next()?;
+
+ let lookup = self.lookup_expression.lookup(operand_id)?;
+ let handle = get_expr_handle!(operand_id, lookup);
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
+ }
+ };
+
+ block.extend(emitter.finish(ctx.expressions));
+ if let Some(stmt) = terminator {
+ block.push(stmt, crate::Span::default());
+ }
+
+ // Save this block fragment in `block_ctx.blocks`, and mark it to be
+ // incorporated into the current body at `Statement` assembly time.
+ ctx.blocks.insert(block_id, block);
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::BlockId(block_id));
+ Ok(())
+ }
+
+ fn make_expression_storage(
+ &mut self,
+ globals: &Arena<crate::GlobalVariable>,
+ constants: &Arena<crate::Constant>,
+ ) -> Arena<crate::Expression> {
+ let mut expressions = Arena::new();
+ #[allow(clippy::panic)]
+ {
+ assert!(self.lookup_expression.is_empty());
+ }
+ // register global variables
+ for (&id, var) in self.lookup_variable.iter() {
+ let span = globals.get_span(var.handle);
+ let handle = expressions.append(crate::Expression::GlobalVariable(var.handle), span);
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ type_id: var.type_id,
+ handle,
+ // Setting this to an invalid id will cause get_expr_handle
+ // to default to the main body making sure no load/stores
+ // are added.
+ block_id: 0,
+ },
+ );
+ }
+ // register special constants
+ self.index_constant_expressions.clear();
+ for &con_handle in self.index_constants.iter() {
+ let span = constants.get_span(con_handle);
+ let handle = expressions.append(crate::Expression::Constant(con_handle), span);
+ self.index_constant_expressions.push(handle);
+ }
+ // register constants
+ for (&id, con) in self.lookup_constant.iter() {
+ let span = constants.get_span(con.handle);
+ let handle = expressions.append(crate::Expression::Constant(con.handle), span);
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ type_id: con.type_id,
+ handle,
+ // Setting this to an invalid id will cause get_expr_handle
+ // to default to the main body making sure no load/stores
+ // are added.
+ block_id: 0,
+ },
+ );
+ }
+ // done
+ expressions
+ }
+
+ fn switch(&mut self, state: ModuleState, op: spirv::Op) -> Result<(), Error> {
+ if state < self.state {
+ Err(Error::UnsupportedInstruction(self.state, op))
+ } else {
+ self.state = state;
+ Ok(())
+ }
+ }
+
+ /// Walk the statement tree and patch it in the following cases:
+ /// 1. Function call targets are replaced by `deferred_function_calls` map
+ fn patch_statements(
+ &mut self,
+ statements: &mut crate::Block,
+ expressions: &mut Arena<crate::Expression>,
+ fun_parameter_sampling: &mut [image::SamplingFlags],
+ ) -> Result<(), Error> {
+ use crate::Statement as S;
+ let mut i = 0usize;
+ while i < statements.len() {
+ match statements[i] {
+ S::Emit(_) => {}
+ S::Block(ref mut block) => {
+ self.patch_statements(block, expressions, fun_parameter_sampling)?;
+ }
+ S::If {
+ condition: _,
+ ref mut accept,
+ ref mut reject,
+ } => {
+ self.patch_statements(reject, expressions, fun_parameter_sampling)?;
+ self.patch_statements(accept, expressions, fun_parameter_sampling)?;
+ }
+ S::Switch {
+ selector: _,
+ ref mut cases,
+ } => {
+ for case in cases.iter_mut() {
+ self.patch_statements(&mut case.body, expressions, fun_parameter_sampling)?;
+ }
+ }
+ S::Loop {
+ ref mut body,
+ ref mut continuing,
+ break_if: _,
+ } => {
+ self.patch_statements(body, expressions, fun_parameter_sampling)?;
+ self.patch_statements(continuing, expressions, fun_parameter_sampling)?;
+ }
+ S::Break
+ | S::Continue
+ | S::Return { .. }
+ | S::Kill
+ | S::Barrier(_)
+ | S::Store { .. }
+ | S::ImageStore { .. }
+ | S::Atomic { .. }
+ | S::RayQuery { .. } => {}
+ S::Call {
+ function: ref mut callee,
+ ref arguments,
+ ..
+ } => {
+ let fun_id = self.deferred_function_calls[callee.index()];
+ let fun_lookup = self.lookup_function.lookup(fun_id)?;
+ *callee = fun_lookup.handle;
+
+ // Patch sampling flags
+ for (arg_index, arg) in arguments.iter().enumerate() {
+ let flags = match fun_lookup.parameters_sampling.get(arg_index) {
+ Some(&flags) if !flags.is_empty() => flags,
+ _ => continue,
+ };
+
+ match expressions[*arg] {
+ crate::Expression::GlobalVariable(handle) => {
+ if let Some(sampling) = self.handle_sampling.get_mut(&handle) {
+ *sampling |= flags
+ }
+ }
+ crate::Expression::FunctionArgument(i) => {
+ fun_parameter_sampling[i as usize] |= flags;
+ }
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ }
+ }
+ }
+ }
+ i += 1;
+ }
+ Ok(())
+ }
+
+ fn patch_function(
+ &mut self,
+ handle: Option<Handle<crate::Function>>,
+ fun: &mut crate::Function,
+ ) -> Result<(), Error> {
+ // Note: this search is a bit unfortunate
+ let (fun_id, mut parameters_sampling) = match handle {
+ Some(h) => {
+ let (&fun_id, lookup) = self
+ .lookup_function
+ .iter_mut()
+ .find(|&(_, ref lookup)| lookup.handle == h)
+ .unwrap();
+ (fun_id, mem::take(&mut lookup.parameters_sampling))
+ }
+ None => (0, Vec::new()),
+ };
+
+ for (_, expr) in fun.expressions.iter_mut() {
+ if let crate::Expression::CallResult(ref mut function) = *expr {
+ let fun_id = self.deferred_function_calls[function.index()];
+ *function = self.lookup_function.lookup(fun_id)?.handle;
+ }
+ }
+
+ self.patch_statements(
+ &mut fun.body,
+ &mut fun.expressions,
+ &mut parameters_sampling,
+ )?;
+
+ if let Some(lookup) = self.lookup_function.get_mut(&fun_id) {
+ lookup.parameters_sampling = parameters_sampling;
+ }
+ Ok(())
+ }
+
+ pub fn parse(mut self) -> Result<crate::Module, Error> {
+ let mut module = {
+ if self.next()? != spirv::MAGIC_NUMBER {
+ return Err(Error::InvalidHeader);
+ }
+ let version_raw = self.next()?;
+ let generator = self.next()?;
+ let _bound = self.next()?;
+ let _schema = self.next()?;
+ log::info!("Generated by {} version {:x}", generator, version_raw);
+ crate::Module::default()
+ };
+
+ // register indexing constants
+ self.index_constants.clear();
+ for i in 0..4 {
+ let handle = module.constants.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Sint(i),
+ },
+ },
+ Default::default(),
+ );
+ self.index_constants.push(handle);
+ }
+
+ self.layouter.clear();
+ self.dummy_functions = Arena::new();
+ self.lookup_function.clear();
+ self.function_call_graph.clear();
+
+ loop {
+ use spirv::Op;
+
+ let inst = match self.next_inst() {
+ Ok(inst) => inst,
+ Err(Error::IncompleteData) => break,
+ Err(other) => return Err(other),
+ };
+ log::debug!("\t{:?} [{}]", inst.op, inst.wc);
+
+ match inst.op {
+ Op::Capability => self.parse_capability(inst),
+ Op::Extension => self.parse_extension(inst),
+ Op::ExtInstImport => self.parse_ext_inst_import(inst),
+ Op::MemoryModel => self.parse_memory_model(inst),
+ Op::EntryPoint => self.parse_entry_point(inst),
+ Op::ExecutionMode => self.parse_execution_mode(inst),
+ Op::String => self.parse_string(inst),
+ Op::Source => self.parse_source(inst),
+ Op::SourceExtension => self.parse_source_extension(inst),
+ Op::Name => self.parse_name(inst),
+ Op::MemberName => self.parse_member_name(inst),
+ Op::ModuleProcessed => self.parse_module_processed(inst),
+ Op::Decorate => self.parse_decorate(inst),
+ Op::MemberDecorate => self.parse_member_decorate(inst),
+ Op::TypeVoid => self.parse_type_void(inst),
+ Op::TypeBool => self.parse_type_bool(inst, &mut module),
+ Op::TypeInt => self.parse_type_int(inst, &mut module),
+ Op::TypeFloat => self.parse_type_float(inst, &mut module),
+ Op::TypeVector => self.parse_type_vector(inst, &mut module),
+ Op::TypeMatrix => self.parse_type_matrix(inst, &mut module),
+ Op::TypeFunction => self.parse_type_function(inst),
+ Op::TypePointer => self.parse_type_pointer(inst, &mut module),
+ Op::TypeArray => self.parse_type_array(inst, &mut module),
+ Op::TypeRuntimeArray => self.parse_type_runtime_array(inst, &mut module),
+ Op::TypeStruct => self.parse_type_struct(inst, &mut module),
+ Op::TypeImage => self.parse_type_image(inst, &mut module),
+ Op::TypeSampledImage => self.parse_type_sampled_image(inst),
+ Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
+ Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
+ Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
+ Op::ConstantNull | Op::Undef => self
+ .parse_null_constant(inst, &module.types, &mut module.constants)
+ .map(|_| ()),
+ Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
+ Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module),
+ Op::Variable => self.parse_global_variable(inst, &mut module),
+ Op::Function => {
+ self.switch(ModuleState::Function, inst.op)?;
+ inst.expect(5)?;
+ self.parse_function(&mut module)
+ }
+ _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO
+ }?;
+ }
+
+ log::info!("Patching...");
+ {
+ let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None)
+ .map_err(|cycle| Error::FunctionCallCycle(cycle.node_id()))?;
+ nodes.reverse(); // we need dominated first
+ let mut functions = mem::take(&mut module.functions);
+ for fun_id in nodes {
+ if fun_id > !(functions.len() as u32) {
+ // skip all the fake IDs registered for the entry points
+ continue;
+ }
+ let lookup = self.lookup_function.get_mut(&fun_id).unwrap();
+ // take out the function from the old array
+ let fun = mem::take(&mut functions[lookup.handle]);
+ // add it to the newly formed arena, and adjust the lookup
+ lookup.handle = module
+ .functions
+ .append(fun, functions.get_span(lookup.handle));
+ }
+ }
+ // patch all the functions
+ for (handle, fun) in module.functions.iter_mut() {
+ self.patch_function(Some(handle), fun)?;
+ }
+ for ep in module.entry_points.iter_mut() {
+ self.patch_function(None, &mut ep.function)?;
+ }
+
+ // Check all the images and samplers to have consistent comparison property.
+ for (handle, flags) in self.handle_sampling.drain() {
+ if !image::patch_comparison_type(
+ flags,
+ module.global_variables.get_mut(handle),
+ &mut module.types,
+ ) {
+ return Err(Error::InconsistentComparisonSampling(handle));
+ }
+ }
+
+ if !self.future_decor.is_empty() {
+ log::warn!("Unused item decorations: {:?}", self.future_decor);
+ self.future_decor.clear();
+ }
+ if !self.future_member_decor.is_empty() {
+ log::warn!("Unused member decorations: {:?}", self.future_member_decor);
+ self.future_member_decor.clear();
+ }
+
+ Ok(module)
+ }
+
+ fn parse_capability(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Capability, inst.op)?;
+ inst.expect(2)?;
+ let capability = self.next()?;
+ let cap =
+ spirv::Capability::from_u32(capability).ok_or(Error::UnknownCapability(capability))?;
+ if !SUPPORTED_CAPABILITIES.contains(&cap) {
+ if self.options.strict_capabilities {
+ return Err(Error::UnsupportedCapability(cap));
+ } else {
+ log::warn!("Unknown capability {:?}", cap);
+ }
+ }
+ Ok(())
+ }
+
+ fn parse_extension(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Extension, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (name, left) = self.next_string(inst.wc - 1)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ if !SUPPORTED_EXTENSIONS.contains(&name.as_str()) {
+ return Err(Error::UnsupportedExtension(name));
+ }
+ Ok(())
+ }
+
+ fn parse_ext_inst_import(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Extension, inst.op)?;
+ inst.expect_at_least(3)?;
+ let result_id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 2)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ if !SUPPORTED_EXT_SETS.contains(&name.as_str()) {
+ return Err(Error::UnsupportedExtSet(name));
+ }
+ self.ext_glsl_id = Some(result_id);
+ Ok(())
+ }
+
+ fn parse_memory_model(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::MemoryModel, inst.op)?;
+ inst.expect(3)?;
+ let _addressing_model = self.next()?;
+ let _memory_model = self.next()?;
+ Ok(())
+ }
+
+ fn parse_entry_point(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::EntryPoint, inst.op)?;
+ inst.expect_at_least(4)?;
+ let exec_model = self.next()?;
+ let exec_model = spirv::ExecutionModel::from_u32(exec_model)
+ .ok_or(Error::UnsupportedExecutionModel(exec_model))?;
+ let function_id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 3)?;
+ let ep = EntryPoint {
+ stage: match exec_model {
+ spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex,
+ spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment,
+ spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute,
+ _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)),
+ },
+ name,
+ early_depth_test: None,
+ workgroup_size: [0; 3],
+ variable_ids: self.data.by_ref().take(left as usize).collect(),
+ };
+ self.lookup_entry_point.insert(function_id, ep);
+ Ok(())
+ }
+
+ fn parse_execution_mode(&mut self, inst: Instruction) -> Result<(), Error> {
+ use spirv::ExecutionMode;
+
+ self.switch(ModuleState::ExecutionMode, inst.op)?;
+ inst.expect_at_least(3)?;
+
+ let ep_id = self.next()?;
+ let mode_id = self.next()?;
+ let args: Vec<spirv::Word> = self.data.by_ref().take(inst.wc as usize - 3).collect();
+
+ let ep = self
+ .lookup_entry_point
+ .get_mut(&ep_id)
+ .ok_or(Error::InvalidId(ep_id))?;
+ let mode = spirv::ExecutionMode::from_u32(mode_id)
+ .ok_or(Error::UnsupportedExecutionMode(mode_id))?;
+
+ match mode {
+ ExecutionMode::EarlyFragmentTests => {
+ if ep.early_depth_test.is_none() {
+ ep.early_depth_test = Some(crate::EarlyDepthTest { conservative: None });
+ }
+ }
+ ExecutionMode::DepthUnchanged => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::Unchanged),
+ });
+ }
+ ExecutionMode::DepthGreater => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::GreaterEqual),
+ });
+ }
+ ExecutionMode::DepthLess => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::LessEqual),
+ });
+ }
+ ExecutionMode::DepthReplacing => {
+ // Ignored because it can be deduced from the IR.
+ }
+ ExecutionMode::OriginUpperLeft => {
+ // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode.
+ }
+ ExecutionMode::LocalSize => {
+ ep.workgroup_size = [args[0], args[1], args[2]];
+ }
+ _ => {
+ return Err(Error::UnsupportedExecutionMode(mode_id));
+ }
+ }
+
+ Ok(())
+ }
+
+ fn parse_string(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ inst.expect_at_least(3)?;
+ let _id = self.next()?;
+ let (_name, _) = self.next_string(inst.wc - 2)?;
+ Ok(())
+ }
+
+ fn parse_source(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ for _ in 1..inst.wc {
+ let _ = self.next()?;
+ }
+ Ok(())
+ }
+
+ fn parse_source_extension(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (_name, _) = self.next_string(inst.wc - 1)?;
+ Ok(())
+ }
+
+ fn parse_name(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 2)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ self.future_decor.entry(id).or_default().name = Some(name);
+ Ok(())
+ }
+
+ fn parse_member_name(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(4)?;
+ let id = self.next()?;
+ let member = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 3)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+
+ self.future_member_decor
+ .entry((id, member))
+ .or_default()
+ .name = Some(name);
+ Ok(())
+ }
+
+ fn parse_module_processed(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (_info, left) = self.next_string(inst.wc - 1)?;
+ //Note: string is ignored
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ Ok(())
+ }
+
+ fn parse_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Annotation, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let mut dec = self.future_decor.remove(&id).unwrap_or_default();
+ self.next_decoration(inst, 2, &mut dec)?;
+ self.future_decor.insert(id, dec);
+ Ok(())
+ }
+
+ fn parse_member_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Annotation, inst.op)?;
+ inst.expect_at_least(4)?;
+ let id = self.next()?;
+ let member = self.next()?;
+
+ let mut dec = self
+ .future_member_decor
+ .remove(&(id, member))
+ .unwrap_or_default();
+ self.next_decoration(inst, 3, &mut dec)?;
+ self.future_member_decor.insert((id, member), dec);
+ Ok(())
+ }
+
+ fn parse_type_void(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ self.lookup_void_type = Some(id);
+ Ok(())
+ }
+
+ fn parse_type_bool(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ let inner = crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_int(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let width = self.next()?;
+ let sign = self.next()?;
+ let inner = crate::TypeInner::Scalar {
+ kind: match sign {
+ 0 => crate::ScalarKind::Uint,
+ 1 => crate::ScalarKind::Sint,
+ _ => return Err(Error::InvalidSign(sign)),
+ },
+ width: map_width(width)?,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_float(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let width = self.next()?;
+ let inner = crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: map_width(width)?,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_vector(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let (kind, width) = match module.types[type_lookup.handle].inner {
+ crate::TypeInner::Scalar { kind, width } => (kind, width),
+ _ => return Err(Error::InvalidInnerType(type_id)),
+ };
+ let component_count = self.next()?;
+ let inner = crate::TypeInner::Vector {
+ size: map_vector_size(component_count)?,
+ kind,
+ width,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_matrix(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let vector_type_id = self.next()?;
+ let num_columns = self.next()?;
+ let decor = self.future_decor.remove(&id);
+
+ let vector_type_lookup = self.lookup_type.lookup(vector_type_id)?;
+ let inner = match module.types[vector_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, width, .. } => crate::TypeInner::Matrix {
+ columns: map_vector_size(num_columns)?,
+ rows: size,
+ width,
+ },
+ _ => return Err(Error::InvalidInnerType(vector_type_id)),
+ };
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(vector_type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_function(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let return_type_id = self.next()?;
+ let parameter_type_ids = self.data.by_ref().take(inst.wc as usize - 3).collect();
+ self.lookup_function_type.insert(
+ id,
+ LookupFunctionType {
+ parameter_type_ids,
+ return_type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_pointer(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let storage_class = self.next()?;
+ let type_id = self.next()?;
+
+ let decor = self.future_decor.remove(&id);
+ let base_lookup_ty = self.lookup_type.lookup(type_id)?;
+ let base_inner = &module.types[base_lookup_ty.handle].inner;
+
+ let space = if let Some(space) = base_inner.pointer_space() {
+ space
+ } else if self
+ .lookup_storage_buffer_types
+ .contains_key(&base_lookup_ty.handle)
+ {
+ crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ }
+ } else {
+ match map_storage_class(storage_class)? {
+ ExtendedClass::Global(space) => space,
+ ExtendedClass::Input | ExtendedClass::Output => crate::AddressSpace::Private,
+ }
+ };
+
+ // We don't support pointers to runtime-sized arrays in the `Uniform`
+ // storage class with the `BufferBlock` decoration. Runtime-sized arrays
+ // should be in the StorageBuffer class.
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = *base_inner
+ {
+ match space {
+ crate::AddressSpace::Storage { .. } => {}
+ _ => {
+ return Err(Error::UnsupportedRuntimeArrayStorageClass);
+ }
+ }
+ }
+
+ // Don't bother with pointer stuff for `Handle` types.
+ let lookup_ty = if space == crate::AddressSpace::Handle {
+ base_lookup_ty.clone()
+ } else {
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.and_then(|dec| dec.name),
+ inner: crate::TypeInner::Pointer {
+ base: base_lookup_ty.handle,
+ space,
+ },
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ }
+ };
+ self.lookup_type.insert(id, lookup_ty);
+ Ok(())
+ }
+
+ fn parse_type_array(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+ let length_id = self.next()?;
+ let length_const = self.lookup_constant.lookup(length_id)?;
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ let base = self.lookup_type.lookup(type_id)?.handle;
+
+ self.layouter
+ .update(&module.types, &module.constants)
+ .unwrap();
+
+ // HACK if the underlying type is an image or a sampler, let's assume
+ // that we're dealing with a binding-array
+ //
+ // Note that it's not a strictly correct assumption, but rather a trade
+ // off caused by an impedance mismatch between SPIR-V's and Naga's type
+ // systems - Naga distinguishes between arrays and binding-arrays via
+ // types (i.e. both kinds of arrays are just different types), while
+ // SPIR-V distinguishes between them through usage - e.g. given:
+ //
+ // ```
+ // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f
+ // %uint_256 = OpConstant %uint 256
+ // %image_array = OpTypeArray %image %uint_256
+ // ```
+ //
+ // ```
+ // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f
+ // %uint_256 = OpConstant %uint 256
+ // %image_array = OpTypeArray %image %uint_256
+ // %image_array_ptr = OpTypePointer UniformConstant %image_array
+ // ```
+ //
+ // ... in the first case, `%image_array` should technically correspond
+ // to `TypeInner::Array`, while in the second case it should say
+ // `TypeInner::BindingArray` (kinda, depending on whether `%image_array`
+ // is ever used as a freestanding type or rather always through the
+ // pointer-indirection).
+ //
+ // Anyway, at the moment we don't support other kinds of image / sampler
+ // arrays than those binding-based, so this assumption is pretty safe
+ // for now.
+ let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } =
+ module.types[base].inner
+ {
+ crate::TypeInner::BindingArray {
+ base,
+ size: crate::ArraySize::Constant(length_const.handle),
+ }
+ } else {
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(length_const.handle),
+ stride: match decor.array_stride {
+ Some(stride) => stride.get(),
+ None => self.layouter[base].to_stride(),
+ },
+ }
+ };
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_runtime_array(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ let base = self.lookup_type.lookup(type_id)?.handle;
+
+ self.layouter
+ .update(&module.types, &module.constants)
+ .unwrap();
+
+ // HACK same case as in `parse_type_array()`
+ let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } =
+ module.types[base].inner
+ {
+ crate::TypeInner::BindingArray {
+ base: self.lookup_type.lookup(type_id)?.handle,
+ size: crate::ArraySize::Dynamic,
+ }
+ } else {
+ crate::TypeInner::Array {
+ base: self.lookup_type.lookup(type_id)?.handle,
+ size: crate::ArraySize::Dynamic,
+ stride: match decor.array_stride {
+ Some(stride) => stride.get(),
+ None => self.layouter[base].to_stride(),
+ },
+ }
+ };
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_struct(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(2)?;
+ let id = self.next()?;
+ let parent_decor = self.future_decor.remove(&id);
+ let is_storage_buffer = parent_decor
+ .as_ref()
+ .map_or(false, |decor| decor.storage_buffer);
+
+ self.layouter
+ .update(&module.types, &module.constants)
+ .unwrap();
+
+ let mut members = Vec::<crate::StructMember>::with_capacity(inst.wc as usize - 2);
+ let mut member_lookups = Vec::with_capacity(members.capacity());
+ let mut storage_access = crate::StorageAccess::empty();
+ let mut span = 0;
+ let mut alignment = Alignment::ONE;
+ for i in 0..u32::from(inst.wc) - 2 {
+ let type_id = self.next()?;
+ let ty = self.lookup_type.lookup(type_id)?.handle;
+ let decor = self
+ .future_member_decor
+ .remove(&(id, i))
+ .unwrap_or_default();
+
+ storage_access |= decor.flags.to_storage_access();
+
+ member_lookups.push(LookupMember {
+ type_id,
+ row_major: decor.matrix_major == Some(Majority::Row),
+ });
+
+ let member_alignment = self.layouter[ty].alignment;
+ span = member_alignment.round_up(span);
+ alignment = member_alignment.max(alignment);
+
+ let binding = decor.io_binding().ok();
+ if let Some(offset) = decor.offset {
+ span = offset;
+ }
+ let offset = span;
+
+ span += self.layouter[ty].size;
+
+ let inner = &module.types[ty].inner;
+ if let crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } = *inner
+ {
+ if let Some(stride) = decor.matrix_stride {
+ let expected_stride = Alignment::from(rows) * width as u32;
+ if stride.get() != expected_stride {
+ return Err(Error::UnsupportedMatrixStride {
+ stride: stride.get(),
+ columns: columns as u8,
+ rows: rows as u8,
+ width,
+ });
+ }
+ }
+ }
+
+ members.push(crate::StructMember {
+ name: decor.name,
+ ty,
+ binding,
+ offset,
+ });
+ }
+
+ span = alignment.round_up(span);
+
+ let inner = crate::TypeInner::Struct { span, members };
+
+ let ty_handle = module.types.insert(
+ crate::Type {
+ name: parent_decor.and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ );
+
+ if is_storage_buffer {
+ self.lookup_storage_buffer_types
+ .insert(ty_handle, storage_access);
+ }
+ for (i, member_lookup) in member_lookups.into_iter().enumerate() {
+ self.lookup_member
+ .insert((ty_handle, i as u32), member_lookup);
+ }
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: ty_handle,
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_image(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(9)?;
+
+ let id = self.next()?;
+ let sample_type_id = self.next()?;
+ let dim = self.next()?;
+ let _is_depth = self.next()?;
+ let is_array = self.next()? != 0;
+ let is_msaa = self.next()? != 0;
+ let _is_sampled = self.next()?;
+ let format = self.next()?;
+
+ let dim = map_image_dim(dim)?;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ // ensure there is a type for texture coordinate without extra components
+ module.types.insert(
+ crate::Type {
+ name: None,
+ inner: {
+ let kind = crate::ScalarKind::Float;
+ let width = 4;
+ match dim.required_coordinate_size() {
+ None => crate::TypeInner::Scalar { kind, width },
+ Some(size) => crate::TypeInner::Vector { size, kind, width },
+ }
+ },
+ },
+ Default::default(),
+ );
+
+ let base_handle = self.lookup_type.lookup(sample_type_id)?.handle;
+ let kind = module.types[base_handle]
+ .inner
+ .scalar_kind()
+ .ok_or(Error::InvalidImageBaseType(base_handle))?;
+
+ let inner = crate::TypeInner::Image {
+ class: if format != 0 {
+ crate::ImageClass::Storage {
+ format: map_image_format(format)?,
+ access: crate::StorageAccess::default(),
+ }
+ } else {
+ crate::ImageClass::Sampled {
+ kind,
+ multi: is_msaa,
+ }
+ },
+ dim,
+ arrayed: is_array,
+ };
+
+ let handle = module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ );
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle,
+ base_id: Some(sample_type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_sampled_image(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let image_id = self.next()?;
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: self.lookup_type.lookup(image_id)?.handle,
+ base_id: Some(image_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_sampler(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ let handle = module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner: crate::TypeInner::Sampler { comparison: false },
+ },
+ self.span_from_with_op(start),
+ );
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle,
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_constant(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(4)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ let inner = match module.types[ty].inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width,
+ } => {
+ let low = self.next()?;
+ let high = if width > 4 {
+ inst.expect(5)?;
+ self.next()?
+ } else {
+ 0
+ };
+ crate::ConstantInner::Scalar {
+ width,
+ value: crate::ScalarValue::Uint((u64::from(high) << 32) | u64::from(low)),
+ }
+ }
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width,
+ } => {
+ let low = self.next()?;
+ let high = if width > 4 {
+ inst.expect(5)?;
+ self.next()?
+ } else {
+ 0
+ };
+ crate::ConstantInner::Scalar {
+ width,
+ value: crate::ScalarValue::Sint(
+ (i64::from(high as i32) << 32) | ((i64::from(low as i32) << 32) >> 32),
+ ),
+ }
+ }
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ } => {
+ let low = self.next()?;
+ let extended = match width {
+ 4 => f64::from(f32::from_bits(low)),
+ 8 => {
+ inst.expect(5)?;
+ let high = self.next()?;
+ f64::from_bits((u64::from(high) << 32) | u64::from(low))
+ }
+ _ => return Err(Error::InvalidTypeWidth(width as u32)),
+ };
+ crate::ConstantInner::Scalar {
+ width,
+ value: crate::ScalarValue::Float(extended),
+ }
+ }
+ _ => return Err(Error::UnsupportedType(type_lookup.handle)),
+ };
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(
+ crate::Constant {
+ specialization: decor.specialization,
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_composite_constant(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(3)?;
+ let type_id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+ let id = self.next()?;
+
+ let mut components = Vec::with_capacity(inst.wc as usize - 3);
+ for _ in 0..components.capacity() {
+ let component_id = self.next()?;
+ let constant = self.lookup_constant.lookup(component_id)?;
+ components.push(constant.handle);
+ }
+
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(
+ crate::Constant {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ specialization: None,
+ inner: crate::ConstantInner::Composite { ty, components },
+ },
+ self.span_from_with_op(start),
+ ),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_null_constant(
+ &mut self,
+ inst: Instruction,
+ types: &UniqueArena<crate::Type>,
+ constants: &mut Arena<crate::Constant>,
+ ) -> Result<(u32, u32, Handle<crate::Constant>), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let span = self.span_from_with_op(start);
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ let inner = null::generate_null_constant(ty, types, constants, span)?;
+ let handle = constants.append(
+ crate::Constant {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ specialization: None, //TODO
+ inner,
+ },
+ span,
+ );
+ self.lookup_constant
+ .insert(id, LookupConstant { handle, type_id });
+ Ok((type_id, id, handle))
+ }
+
+ fn parse_bool_constant(
+ &mut self,
+ inst: Instruction,
+ value: bool,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(
+ crate::Constant {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ specialization: None, //TODO
+ inner: crate::ConstantInner::boolean(value),
+ },
+ self.span_from_with_op(start),
+ ),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_global_variable(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(4)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let storage_class = self.next()?;
+ let init = if inst.wc > 4 {
+ inst.expect(5)?;
+ let init_id = self.next()?;
+ let lconst = self.lookup_constant.lookup(init_id)?;
+ Some(lconst.handle)
+ } else {
+ None
+ };
+ let span = self.span_from_with_op(start);
+ let mut dec = self.future_decor.remove(&id).unwrap_or_default();
+
+ let original_ty = self.lookup_type.lookup(type_id)?.handle;
+ let mut ty = original_ty;
+
+ if let crate::TypeInner::Pointer { base, space: _ } = module.types[original_ty].inner {
+ ty = base;
+ }
+
+ if let crate::TypeInner::BindingArray { .. } = module.types[original_ty].inner {
+ // Inside `parse_type_array()` we guess that an array of images or
+ // samplers must be a binding array, and here we validate that guess
+ if dec.desc_set.is_none() || dec.desc_index.is_none() {
+ return Err(Error::NonBindingArrayOfImageOrSamplers);
+ }
+ }
+
+ if let crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Storage { format, access: _ },
+ } = module.types[ty].inner
+ {
+ // Storage image types in IR have to contain the access, but not in the SPIR-V.
+ // The same image type in SPIR-V can be used (and has to be used) for multiple images.
+ // So we copy the type out and apply the variable access decorations.
+ let access = dec.flags.to_storage_access();
+
+ ty = module.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Storage { format, access },
+ },
+ },
+ Default::default(),
+ );
+ }
+
+ let ext_class = match self.lookup_storage_buffer_types.get(&ty) {
+ Some(&access) => ExtendedClass::Global(crate::AddressSpace::Storage { access }),
+ None => map_storage_class(storage_class)?,
+ };
+
+ // Fix empty name for gl_PerVertex struct generated by glslang
+ if let crate::TypeInner::Pointer { .. } = module.types[original_ty].inner {
+ if ext_class == ExtendedClass::Input || ext_class == ExtendedClass::Output {
+ if let Some(ref dec_name) = dec.name {
+ if dec_name.is_empty() {
+ dec.name = Some("perVertexStruct".to_string())
+ }
+ }
+ }
+ }
+
+ let (inner, var) = match ext_class {
+ ExtendedClass::Global(mut space) => {
+ if let crate::AddressSpace::Storage { ref mut access } = space {
+ *access &= dec.flags.to_storage_access();
+ }
+ let var = crate::GlobalVariable {
+ binding: dec.resource_binding(),
+ name: dec.name,
+ space,
+ ty,
+ init,
+ };
+ (Variable::Global, var)
+ }
+ ExtendedClass::Input => {
+ let binding = dec.io_binding()?;
+ let mut unsigned_ty = ty;
+ if let crate::Binding::BuiltIn(built_in) = binding {
+ let needs_inner_uint = match built_in {
+ crate::BuiltIn::BaseInstance
+ | crate::BuiltIn::BaseVertex
+ | crate::BuiltIn::InstanceIndex
+ | crate::BuiltIn::SampleIndex
+ | crate::BuiltIn::VertexIndex
+ | crate::BuiltIn::PrimitiveIndex
+ | crate::BuiltIn::LocalInvocationIndex => Some(crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }),
+ crate::BuiltIn::GlobalInvocationId
+ | crate::BuiltIn::LocalInvocationId
+ | crate::BuiltIn::WorkGroupId
+ | crate::BuiltIn::WorkGroupSize => Some(crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }),
+ _ => None,
+ };
+ if let (Some(inner), Some(crate::ScalarKind::Sint)) =
+ (needs_inner_uint, module.types[ty].inner.scalar_kind())
+ {
+ unsigned_ty = module
+ .types
+ .insert(crate::Type { name: None, inner }, Default::default());
+ }
+ }
+
+ let var = crate::GlobalVariable {
+ name: dec.name.clone(),
+ space: crate::AddressSpace::Private,
+ binding: None,
+ ty,
+ init: None,
+ };
+
+ let inner = Variable::Input(crate::FunctionArgument {
+ name: dec.name,
+ ty: unsigned_ty,
+ binding: Some(binding),
+ });
+ (inner, var)
+ }
+ ExtendedClass::Output => {
+ // For output interface blocks, this would be a structure.
+ let binding = dec.io_binding().ok();
+ let init = match binding {
+ Some(crate::Binding::BuiltIn(built_in)) => {
+ match null::generate_default_built_in(
+ Some(built_in),
+ ty,
+ &module.types,
+ &mut module.constants,
+ span,
+ ) {
+ Ok(handle) => Some(handle),
+ Err(e) => {
+ log::warn!("Failed to initialize output built-in: {}", e);
+ None
+ }
+ }
+ }
+ Some(crate::Binding::Location { .. }) => None,
+ None => match module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ // A temporary to avoid borrowing `module.types`
+ let pairs = members
+ .iter()
+ .map(|member| {
+ let built_in = match member.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => Some(built_in),
+ _ => None,
+ };
+ (built_in, member.ty)
+ })
+ .collect::<Vec<_>>();
+
+ let mut components = Vec::with_capacity(members.len());
+ for (built_in, member_ty) in pairs {
+ let handle = null::generate_default_built_in(
+ built_in,
+ member_ty,
+ &module.types,
+ &mut module.constants,
+ span,
+ )?;
+ components.push(handle);
+ }
+ Some(module.constants.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Composite { ty, components },
+ },
+ span,
+ ))
+ }
+ _ => None,
+ },
+ };
+
+ let var = crate::GlobalVariable {
+ name: dec.name,
+ space: crate::AddressSpace::Private,
+ binding: None,
+ ty,
+ init,
+ };
+ let inner = Variable::Output(crate::FunctionResult { ty, binding });
+ (inner, var)
+ }
+ };
+
+ let handle = module.global_variables.append(var, span);
+
+ if module.types[ty].inner.can_comparison_sample(module) {
+ log::debug!("\t\ttracking {:?} for sampling properties", handle);
+
+ self.handle_sampling
+ .insert(handle, image::SamplingFlags::empty());
+ }
+
+ self.lookup_variable.insert(
+ id,
+ LookupVariable {
+ inner,
+ handle,
+ type_id,
+ },
+ );
+ Ok(())
+ }
+}
+
+pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> {
+ if data.len() % 4 != 0 {
+ return Err(Error::IncompleteData);
+ }
+
+ let words = data
+ .chunks(4)
+ .map(|c| u32::from_le_bytes(c.try_into().unwrap()));
+ Frontend::new(words, options).parse()
+}
+
+#[cfg(test)]
+mod test {
+ #[test]
+ fn parse() {
+ let bin = vec![
+ // Magic number. Version number: 1.0.
+ 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00,
+ // Generator number: 0. Bound: 0.
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Reserved word: 0.
+ 0x00, 0x00, 0x00, 0x00, // OpMemoryModel. Logical.
+ 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450.
+ 0x01, 0x00, 0x00, 0x00,
+ ];
+ let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap();
+ }
+}
+
+/// Helper function to check if `child` is in the scope of `parent`
+fn is_parent(mut child: usize, parent: usize, block_ctx: &BlockContext) -> bool {
+ loop {
+ if child == parent {
+ // The child is in the scope parent
+ break true;
+ } else if child == 0 {
+ // Searched finished at the root the child isn't in the parent's body
+ break false;
+ }
+
+ child = block_ctx.bodies[child].parent;
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/null.rs b/third_party/rust/naga/src/front/spv/null.rs
new file mode 100644
index 0000000000..85350c563e
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/null.rs
@@ -0,0 +1,175 @@
+use super::Error;
+use crate::arena::{Arena, Handle, UniqueArena};
+
+const fn make_scalar_inner(kind: crate::ScalarKind, width: crate::Bytes) -> crate::ConstantInner {
+ crate::ConstantInner::Scalar {
+ width,
+ value: match kind {
+ crate::ScalarKind::Uint => crate::ScalarValue::Uint(0),
+ crate::ScalarKind::Sint => crate::ScalarValue::Sint(0),
+ crate::ScalarKind::Float => crate::ScalarValue::Float(0.0),
+ crate::ScalarKind::Bool => crate::ScalarValue::Bool(false),
+ },
+ }
+}
+
+pub fn generate_null_constant(
+ ty: Handle<crate::Type>,
+ type_arena: &UniqueArena<crate::Type>,
+ constant_arena: &mut Arena<crate::Constant>,
+ span: crate::Span,
+) -> Result<crate::ConstantInner, Error> {
+ let inner = match type_arena[ty].inner {
+ crate::TypeInner::Scalar { kind, width } => make_scalar_inner(kind, width),
+ crate::TypeInner::Vector { size, kind, width } => {
+ let mut components = Vec::with_capacity(size as usize);
+ for _ in 0..size as usize {
+ components.push(constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: make_scalar_inner(kind, width),
+ },
+ span,
+ ));
+ }
+ crate::ConstantInner::Composite { ty, components }
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ // If we successfully declared a matrix type, we have declared a vector type for it too.
+ let vector_ty = type_arena
+ .get(&crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Float,
+ size: rows,
+ width,
+ },
+ })
+ .unwrap();
+ let vector_inner = generate_null_constant(vector_ty, type_arena, constant_arena, span)?;
+ let vector_handle = constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: vector_inner,
+ },
+ span,
+ );
+ crate::ConstantInner::Composite {
+ ty,
+ components: vec![vector_handle; columns as usize],
+ }
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ let mut components = Vec::with_capacity(members.len());
+ // copy out the types to avoid borrowing `members`
+ let member_tys = members.iter().map(|member| member.ty).collect::<Vec<_>>();
+ for member_ty in member_tys {
+ let inner = generate_null_constant(member_ty, type_arena, constant_arena, span)?;
+ components.push(constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ span,
+ ));
+ }
+ crate::ConstantInner::Composite { ty, components }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(handle),
+ ..
+ } => {
+ let size = constant_arena[handle]
+ .to_array_length()
+ .ok_or(Error::InvalidArraySize(handle))?;
+ let inner = generate_null_constant(base, type_arena, constant_arena, span)?;
+ let value = constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ span,
+ );
+ crate::ConstantInner::Composite {
+ ty,
+ components: vec![value; size as usize],
+ }
+ }
+ ref other => {
+ log::warn!("null constant type {:?}", other);
+ return Err(Error::UnsupportedType(ty));
+ }
+ };
+ Ok(inner)
+}
+
+/// Create a default value for an output built-in.
+pub fn generate_default_built_in(
+ built_in: Option<crate::BuiltIn>,
+ ty: Handle<crate::Type>,
+ type_arena: &UniqueArena<crate::Type>,
+ constant_arena: &mut Arena<crate::Constant>,
+ span: crate::Span,
+) -> Result<Handle<crate::Constant>, Error> {
+ let inner = match built_in {
+ Some(crate::BuiltIn::Position { .. }) => {
+ let zero = constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(0.0),
+ width: 4,
+ },
+ },
+ span,
+ );
+ let one = constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(1.0),
+ width: 4,
+ },
+ },
+ span,
+ );
+ crate::ConstantInner::Composite {
+ ty,
+ components: vec![zero, zero, zero, one],
+ }
+ }
+ Some(crate::BuiltIn::PointSize) => crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(1.0),
+ width: 4,
+ },
+ Some(crate::BuiltIn::FragDepth) => crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(0.0),
+ width: 4,
+ },
+ Some(crate::BuiltIn::SampleMask) => crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(u64::MAX),
+ width: 4,
+ },
+ //Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path
+ _ => generate_null_constant(ty, type_arena, constant_arena, span)?,
+ };
+ Ok(constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ span,
+ ))
+}
diff --git a/third_party/rust/naga/src/front/type_gen.rs b/third_party/rust/naga/src/front/type_gen.rs
new file mode 100644
index 0000000000..1ee454c448
--- /dev/null
+++ b/third_party/rust/naga/src/front/type_gen.rs
@@ -0,0 +1,314 @@
+/*!
+Type generators.
+*/
+
+use crate::{arena::Handle, span::Span};
+
+impl crate::Module {
+ pub fn generate_atomic_compare_exchange_result(
+ &mut self,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ ) -> Handle<crate::Type> {
+ let bool_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let scalar_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar { kind, width },
+ },
+ Span::UNDEFINED,
+ );
+
+ self.types.insert(
+ crate::Type {
+ name: Some(format!(
+ "__atomic_compare_exchange_result<{kind:?},{width}>"
+ )),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("old_value".to_string()),
+ ty: scalar_ty,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("exchanged".to_string()),
+ ty: bool_ty,
+ binding: None,
+ offset: 4,
+ },
+ ],
+ span: 8,
+ },
+ },
+ Span::UNDEFINED,
+ )
+ }
+ /// Populate this module's [`SpecialTypes::ray_desc`] type.
+ ///
+ /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of
+ /// an [`Initialize`] [`RayQuery`] statement. In WGSL, it is a struct type
+ /// referred to as `RayDesc`.
+ ///
+ /// Backends consume values of this type to drive platform APIs, so if you
+ /// change any its fields, you must update the backends to match. Look for
+ /// backend code dealing with [`RayQueryFunction::Initialize`].
+ ///
+ /// [`SpecialTypes::ray_desc`]: crate::SpecialTypes::ray_desc
+ /// [`descriptor`]: crate::RayQueryFunction::Initialize::descriptor
+ /// [`Initialize`]: crate::RayQueryFunction::Initialize
+ /// [`RayQuery`]: crate::Statement::RayQuery
+ /// [`RayQueryFunction::Initialize`]: crate::RayQueryFunction::Initialize
+ pub fn generate_ray_desc_type(&mut self) -> Handle<crate::Type> {
+ if let Some(handle) = self.special_types.ray_desc {
+ return handle;
+ }
+
+ let width = 4;
+ let ty_flag = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width,
+ kind: crate::ScalarKind::Uint,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_scalar = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width,
+ kind: crate::ScalarKind::Float,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_vector = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Float,
+ width,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ let handle = self.types.insert(
+ crate::Type {
+ name: Some("RayDesc".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("flags".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("cull_mask".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 4,
+ },
+ crate::StructMember {
+ name: Some("tmin".to_string()),
+ ty: ty_scalar,
+ binding: None,
+ offset: 8,
+ },
+ crate::StructMember {
+ name: Some("tmax".to_string()),
+ ty: ty_scalar,
+ binding: None,
+ offset: 12,
+ },
+ crate::StructMember {
+ name: Some("origin".to_string()),
+ ty: ty_vector,
+ binding: None,
+ offset: 16,
+ },
+ crate::StructMember {
+ name: Some("dir".to_string()),
+ ty: ty_vector,
+ binding: None,
+ offset: 32,
+ },
+ ],
+ span: 48,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ self.special_types.ray_desc = Some(handle);
+ handle
+ }
+
+ /// Populate this module's [`SpecialTypes::ray_intersection`] type.
+ ///
+ /// [`SpecialTypes::ray_intersection`] is the type of a
+ /// `RayQueryGetIntersection` expression. In WGSL, it is a struct type
+ /// referred to as `RayIntersection`.
+ ///
+ /// Backends construct values of this type based on platform APIs, so if you
+ /// change any its fields, you must update the backends to match. Look for
+ /// the backend's handling for [`Expression::RayQueryGetIntersection`].
+ ///
+ /// [`SpecialTypes::ray_intersection`]: crate::SpecialTypes::ray_intersection
+ /// [`Expression::RayQueryGetIntersection`]: crate::Expression::RayQueryGetIntersection
+ pub fn generate_ray_intersection_type(&mut self) -> Handle<crate::Type> {
+ if let Some(handle) = self.special_types.ray_intersection {
+ return handle;
+ }
+
+ let width = 4;
+ let ty_flag = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width,
+ kind: crate::ScalarKind::Uint,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_scalar = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width,
+ kind: crate::ScalarKind::Float,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_barycentrics = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ width,
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Float,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_bool = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width: crate::BOOL_WIDTH,
+ kind: crate::ScalarKind::Bool,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_transform = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ let handle = self.types.insert(
+ crate::Type {
+ name: Some("RayIntersection".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("kind".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("t".to_string()),
+ ty: ty_scalar,
+ binding: None,
+ offset: 4,
+ },
+ crate::StructMember {
+ name: Some("instance_custom_index".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 8,
+ },
+ crate::StructMember {
+ name: Some("instance_id".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 12,
+ },
+ crate::StructMember {
+ name: Some("sbt_record_offset".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 16,
+ },
+ crate::StructMember {
+ name: Some("geometry_index".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 20,
+ },
+ crate::StructMember {
+ name: Some("primitive_index".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 24,
+ },
+ crate::StructMember {
+ name: Some("barycentrics".to_string()),
+ ty: ty_barycentrics,
+ binding: None,
+ offset: 28,
+ },
+ crate::StructMember {
+ name: Some("front_face".to_string()),
+ ty: ty_bool,
+ binding: None,
+ offset: 36,
+ },
+ crate::StructMember {
+ name: Some("object_to_world".to_string()),
+ ty: ty_transform,
+ binding: None,
+ offset: 48,
+ },
+ crate::StructMember {
+ name: Some("world_to_object".to_string()),
+ ty: ty_transform,
+ binding: None,
+ offset: 112,
+ },
+ ],
+ span: 176,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ self.special_types.ray_intersection = Some(handle);
+ handle
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs
new file mode 100644
index 0000000000..82f7c609ee
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/error.rs
@@ -0,0 +1,690 @@
+use crate::front::wgsl::parse::lexer::Token;
+use crate::proc::{Alignment, ResolveError};
+use crate::{SourceLocation, Span};
+use codespan_reporting::diagnostic::{Diagnostic, Label};
+use codespan_reporting::files::SimpleFile;
+use codespan_reporting::term;
+use std::borrow::Cow;
+use std::ops::Range;
+use termcolor::{ColorChoice, NoColor, StandardStream};
+use thiserror::Error;
+
+#[derive(Clone, Debug)]
+pub struct ParseError {
+ message: String,
+ labels: Vec<(Span, Cow<'static, str>)>,
+ notes: Vec<String>,
+}
+
+impl ParseError {
+ pub fn labels(&self) -> impl Iterator<Item = (Span, &str)> + ExactSizeIterator + '_ {
+ self.labels
+ .iter()
+ .map(|&(span, ref msg)| (span, msg.as_ref()))
+ }
+
+ pub fn message(&self) -> &str {
+ &self.message
+ }
+
+ fn diagnostic(&self) -> Diagnostic<()> {
+ let diagnostic = Diagnostic::error()
+ .with_message(self.message.to_string())
+ .with_labels(
+ self.labels
+ .iter()
+ .map(|label| {
+ Label::primary((), label.0.to_range().unwrap())
+ .with_message(label.1.to_string())
+ })
+ .collect(),
+ )
+ .with_notes(
+ self.notes
+ .iter()
+ .map(|note| format!("note: {note}"))
+ .collect(),
+ );
+ diagnostic
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr(&self, source: &str) {
+ self.emit_to_stderr_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr_with_path(&self, source: &str, path: &str) {
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let writer = StandardStream::stderr(ColorChoice::Auto);
+ term::emit(&mut writer.lock(), &config, &files, &self.diagnostic())
+ .expect("cannot write error");
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string(&self, source: &str) -> String {
+ self.emit_to_string_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string_with_path(&self, source: &str, path: &str) -> String {
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let mut writer = NoColor::new(Vec::new());
+ term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error");
+ String::from_utf8(writer.into_inner()).unwrap()
+ }
+
+ /// Returns a [`SourceLocation`] for the first label in the error message.
+ pub fn location(&self, source: &str) -> Option<SourceLocation> {
+ self.labels.get(0).map(|label| label.0.location(source))
+ }
+}
+
+impl std::fmt::Display for ParseError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.message)
+ }
+}
+
+impl std::error::Error for ParseError {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ None
+ }
+}
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum ExpectedToken<'a> {
+ Token(Token<'a>),
+ Identifier,
+ Number,
+ Integer,
+ /// A compile-time constant expression.
+ Constant,
+ /// Expected: constant, parenthesized expression, identifier
+ PrimaryExpression,
+ /// Expected: assignment, increment/decrement expression
+ Assignment,
+ /// Expected: 'case', 'default', '}'
+ SwitchItem,
+ /// Expected: ',', ')'
+ WorkgroupSizeSeparator,
+ /// Expected: 'struct', 'let', 'var', 'type', ';', 'fn', eof
+ GlobalItem,
+ /// Expected a type.
+ Type,
+ /// Access of `var`, `let`, `const`.
+ Variable,
+ /// Access of a function
+ Function,
+}
+
+#[derive(Clone, Copy, Debug, Error, PartialEq)]
+pub enum NumberError {
+ #[error("invalid numeric literal format")]
+ Invalid,
+ #[error("numeric literal not representable by target type")]
+ NotRepresentable,
+ #[error("unimplemented f16 type")]
+ UnimplementedF16,
+}
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum InvalidAssignmentType {
+ Other,
+ Swizzle,
+ ImmutableBinding(Span),
+}
+
+#[derive(Clone, Debug)]
+pub enum Error<'a> {
+ Unexpected(Span, ExpectedToken<'a>),
+ UnexpectedComponents(Span),
+ BadNumber(Span, NumberError),
+ /// A negative signed integer literal where both signed and unsigned,
+ /// but only non-negative literals are allowed.
+ NegativeInt(Span),
+ BadU32Constant(Span),
+ BadMatrixScalarKind(Span, crate::ScalarKind, u8),
+ BadAccessor(Span),
+ BadTexture(Span),
+ BadTypeCast {
+ span: Span,
+ from_type: String,
+ to_type: String,
+ },
+ BadTextureSampleType {
+ span: Span,
+ kind: crate::ScalarKind,
+ width: u8,
+ },
+ BadIncrDecrReferenceType(Span),
+ InvalidResolve(ResolveError),
+ InvalidForInitializer(Span),
+ /// A break if appeared outside of a continuing block
+ InvalidBreakIf(Span),
+ InvalidGatherComponent(Span),
+ InvalidConstructorComponentType(Span, i32),
+ InvalidIdentifierUnderscore(Span),
+ ReservedIdentifierPrefix(Span),
+ UnknownAddressSpace(Span),
+ UnknownAttribute(Span),
+ UnknownBuiltin(Span),
+ UnknownAccess(Span),
+ UnknownIdent(Span, &'a str),
+ UnknownScalarType(Span),
+ UnknownType(Span),
+ UnknownStorageFormat(Span),
+ UnknownConservativeDepth(Span),
+ SizeAttributeTooLow(Span, u32),
+ AlignAttributeTooLow(Span, Alignment),
+ NonPowerOfTwoAlignAttribute(Span),
+ InconsistentBinding(Span),
+ TypeNotConstructible(Span),
+ TypeNotInferrable(Span),
+ InitializationTypeMismatch(Span, String, String),
+ MissingType(Span),
+ MissingAttribute(&'static str, Span),
+ InvalidAtomicPointer(Span),
+ InvalidAtomicOperandType(Span),
+ InvalidRayQueryPointer(Span),
+ Pointer(&'static str, Span),
+ NotPointer(Span),
+ NotReference(&'static str, Span),
+ InvalidAssignment {
+ span: Span,
+ ty: InvalidAssignmentType,
+ },
+ ReservedKeyword(Span),
+ /// Redefinition of an identifier (used for both module-scope and local redefinitions).
+ Redefinition {
+ /// Span of the identifier in the previous definition.
+ previous: Span,
+
+ /// Span of the identifier in the new definition.
+ current: Span,
+ },
+ /// A declaration refers to itself directly.
+ RecursiveDeclaration {
+ /// The location of the name of the declaration.
+ ident: Span,
+
+ /// The point at which it is used.
+ usage: Span,
+ },
+ /// A declaration refers to itself indirectly, through one or more other
+ /// definitions.
+ CyclicDeclaration {
+ /// The location of the name of some declaration in the cycle.
+ ident: Span,
+
+ /// The edges of the cycle of references.
+ ///
+ /// Each `(decl, reference)` pair indicates that the declaration whose
+ /// name is `decl` has an identifier at `reference` whose definition is
+ /// the next declaration in the cycle. The last pair's `reference` is
+ /// the same identifier as `ident`, above.
+ path: Vec<(Span, Span)>,
+ },
+ ConstExprUnsupported(Span),
+ InvalidSwitchValue {
+ uint: bool,
+ span: Span,
+ },
+ CalledEntryPoint(Span),
+ WrongArgumentCount {
+ span: Span,
+ expected: Range<u32>,
+ found: u32,
+ },
+ FunctionReturnsVoid(Span),
+ Other,
+}
+
+impl<'a> Error<'a> {
+ pub(crate) fn as_parse_error(&self, source: &'a str) -> ParseError {
+ match *self {
+ Error::Unexpected(unexpected_span, expected) => {
+ let expected_str = match expected {
+ ExpectedToken::Token(token) => {
+ match token {
+ Token::Separator(c) => format!("'{c}'"),
+ Token::Paren(c) => format!("'{c}'"),
+ Token::Attribute => "@".to_string(),
+ Token::Number(_) => "number".to_string(),
+ Token::Word(s) => s.to_string(),
+ Token::Operation(c) => format!("operation ('{c}')"),
+ Token::LogicalOperation(c) => format!("logical operation ('{c}')"),
+ Token::ShiftOperation(c) => format!("bitshift ('{c}{c}')"),
+ Token::AssignmentOperation(c) if c=='<' || c=='>' => format!("bitshift ('{c}{c}=')"),
+ Token::AssignmentOperation(c) => format!("operation ('{c}=')"),
+ Token::IncrementOperation => "increment operation".to_string(),
+ Token::DecrementOperation => "decrement operation".to_string(),
+ Token::Arrow => "->".to_string(),
+ Token::Unknown(c) => format!("unknown ('{c}')"),
+ Token::Trivia => "trivia".to_string(),
+ Token::End => "end".to_string(),
+ }
+ }
+ ExpectedToken::Identifier => "identifier".to_string(),
+ ExpectedToken::Number => "32-bit signed integer literal".to_string(),
+ ExpectedToken::Integer => "unsigned/signed integer literal".to_string(),
+ ExpectedToken::Constant => "compile-time constant".to_string(),
+ ExpectedToken::PrimaryExpression => "expression".to_string(),
+ ExpectedToken::Assignment => "assignment or increment/decrement".to_string(),
+ ExpectedToken::SwitchItem => "switch item ('case' or 'default') or a closing curly bracket to signify the end of the switch statement ('}')".to_string(),
+ ExpectedToken::WorkgroupSizeSeparator => "workgroup size separator (',') or a closing parenthesis".to_string(),
+ ExpectedToken::GlobalItem => "global item ('struct', 'const', 'var', 'alias', ';', 'fn') or the end of the file".to_string(),
+ ExpectedToken::Type => "type".to_string(),
+ ExpectedToken::Variable => "variable access".to_string(),
+ ExpectedToken::Function => "function name".to_string(),
+ };
+ ParseError {
+ message: format!(
+ "expected {}, found '{}'",
+ expected_str, &source[unexpected_span],
+ ),
+ labels: vec![(unexpected_span, format!("expected {expected_str}").into())],
+ notes: vec![],
+ }
+ }
+ Error::UnexpectedComponents(bad_span) => ParseError {
+ message: "unexpected components".to_string(),
+ labels: vec![(bad_span, "unexpected components".into())],
+ notes: vec![],
+ },
+ Error::BadNumber(bad_span, ref err) => ParseError {
+ message: format!("{}: `{}`", err, &source[bad_span],),
+ labels: vec![(bad_span, err.to_string().into())],
+ notes: vec![],
+ },
+ Error::NegativeInt(bad_span) => ParseError {
+ message: format!(
+ "expected non-negative integer literal, found `{}`",
+ &source[bad_span],
+ ),
+ labels: vec![(bad_span, "expected non-negative integer".into())],
+ notes: vec![],
+ },
+ Error::BadU32Constant(bad_span) => ParseError {
+ message: format!(
+ "expected unsigned integer constant expression, found `{}`",
+ &source[bad_span],
+ ),
+ labels: vec![(bad_span, "expected unsigned integer".into())],
+ notes: vec![],
+ },
+ Error::BadMatrixScalarKind(span, kind, width) => ParseError {
+ message: format!(
+ "matrix scalar type must be floating-point, but found `{}`",
+ kind.to_wgsl(width)
+ ),
+ labels: vec![(span, "must be floating-point (e.g. `f32`)".into())],
+ notes: vec![],
+ },
+ Error::BadAccessor(accessor_span) => ParseError {
+ message: format!("invalid field accessor `{}`", &source[accessor_span],),
+ labels: vec![(accessor_span, "invalid accessor".into())],
+ notes: vec![],
+ },
+ Error::UnknownIdent(ident_span, ident) => ParseError {
+ message: format!("no definition in scope for identifier: '{ident}'"),
+ labels: vec![(ident_span, "unknown identifier".into())],
+ notes: vec![],
+ },
+ Error::UnknownScalarType(bad_span) => ParseError {
+ message: format!("unknown scalar type: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown scalar type".into())],
+ notes: vec!["Valid scalar types are f32, f64, i32, u32, bool".into()],
+ },
+ Error::BadTextureSampleType { span, kind, width } => ParseError {
+ message: format!(
+ "texture sample type must be one of f32, i32 or u32, but found {}",
+ kind.to_wgsl(width)
+ ),
+ labels: vec![(span, "must be one of f32, i32 or u32".into())],
+ notes: vec![],
+ },
+ Error::BadIncrDecrReferenceType(span) => ParseError {
+ message:
+ "increment/decrement operation requires reference type to be one of i32 or u32"
+ .to_string(),
+ labels: vec![(span, "must be a reference type of i32 or u32".into())],
+ notes: vec![],
+ },
+ Error::BadTexture(bad_span) => ParseError {
+ message: format!(
+ "expected an image, but found '{}' which is not an image",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "not an image".into())],
+ notes: vec![],
+ },
+ Error::BadTypeCast {
+ span,
+ ref from_type,
+ ref to_type,
+ } => {
+ let msg = format!("cannot cast a {from_type} to a {to_type}");
+ ParseError {
+ message: msg.clone(),
+ labels: vec![(span, msg.into())],
+ notes: vec![],
+ }
+ }
+ Error::InvalidResolve(ref resolve_error) => ParseError {
+ message: resolve_error.to_string(),
+ labels: vec![],
+ notes: vec![],
+ },
+ Error::InvalidForInitializer(bad_span) => ParseError {
+ message: format!(
+ "for(;;) initializer is not an assignment or a function call: '{}'",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "not an assignment or function call".into())],
+ notes: vec![],
+ },
+ Error::InvalidBreakIf(bad_span) => ParseError {
+ message: "A break if is only allowed in a continuing block".to_string(),
+ labels: vec![(bad_span, "not in a continuing block".into())],
+ notes: vec![],
+ },
+ Error::InvalidGatherComponent(bad_span) => ParseError {
+ message: format!(
+ "textureGather component '{}' doesn't exist, must be 0, 1, 2, or 3",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "invalid component".into())],
+ notes: vec![],
+ },
+ Error::InvalidConstructorComponentType(bad_span, component) => ParseError {
+ message: format!("invalid type for constructor component at index [{component}]"),
+ labels: vec![(bad_span, "invalid component type".into())],
+ notes: vec![],
+ },
+ Error::InvalidIdentifierUnderscore(bad_span) => ParseError {
+ message: "Identifier can't be '_'".to_string(),
+ labels: vec![(bad_span, "invalid identifier".into())],
+ notes: vec![
+ "Use phony assignment instead ('_ =' notice the absence of 'let' or 'var')"
+ .to_string(),
+ ],
+ },
+ Error::ReservedIdentifierPrefix(bad_span) => ParseError {
+ message: format!(
+ "Identifier starts with a reserved prefix: '{}'",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "invalid identifier".into())],
+ notes: vec![],
+ },
+ Error::UnknownAddressSpace(bad_span) => ParseError {
+ message: format!("unknown address space: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown address space".into())],
+ notes: vec![],
+ },
+ Error::UnknownAttribute(bad_span) => ParseError {
+ message: format!("unknown attribute: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown attribute".into())],
+ notes: vec![],
+ },
+ Error::UnknownBuiltin(bad_span) => ParseError {
+ message: format!("unknown builtin: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown builtin".into())],
+ notes: vec![],
+ },
+ Error::UnknownAccess(bad_span) => ParseError {
+ message: format!("unknown access: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown access".into())],
+ notes: vec![],
+ },
+ Error::UnknownStorageFormat(bad_span) => ParseError {
+ message: format!("unknown storage format: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown storage format".into())],
+ notes: vec![],
+ },
+ Error::UnknownConservativeDepth(bad_span) => ParseError {
+ message: format!("unknown conservative depth: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown conservative depth".into())],
+ notes: vec![],
+ },
+ Error::UnknownType(bad_span) => ParseError {
+ message: format!("unknown type: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown type".into())],
+ notes: vec![],
+ },
+ Error::SizeAttributeTooLow(bad_span, min_size) => ParseError {
+ message: format!("struct member size must be at least {min_size}"),
+ labels: vec![(bad_span, format!("must be at least {min_size}").into())],
+ notes: vec![],
+ },
+ Error::AlignAttributeTooLow(bad_span, min_align) => ParseError {
+ message: format!("struct member alignment must be at least {min_align}"),
+ labels: vec![(bad_span, format!("must be at least {min_align}").into())],
+ notes: vec![],
+ },
+ Error::NonPowerOfTwoAlignAttribute(bad_span) => ParseError {
+ message: "struct member alignment must be a power of 2".to_string(),
+ labels: vec![(bad_span, "must be a power of 2".into())],
+ notes: vec![],
+ },
+ Error::InconsistentBinding(span) => ParseError {
+ message: "input/output binding is not consistent".to_string(),
+ labels: vec![(span, "input/output binding is not consistent".into())],
+ notes: vec![],
+ },
+ Error::TypeNotConstructible(span) => ParseError {
+ message: format!("type `{}` is not constructible", &source[span]),
+ labels: vec![(span, "type is not constructible".into())],
+ notes: vec![],
+ },
+ Error::TypeNotInferrable(span) => ParseError {
+ message: "type can't be inferred".to_string(),
+ labels: vec![(span, "type can't be inferred".into())],
+ notes: vec![],
+ },
+ Error::InitializationTypeMismatch(name_span, ref expected_ty, ref got_ty) => {
+ ParseError {
+ message: format!(
+ "the type of `{}` is expected to be `{}`, but got `{}`",
+ &source[name_span], expected_ty, got_ty,
+ ),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ }
+ }
+ Error::MissingType(name_span) => ParseError {
+ message: format!("variable `{}` needs a type", &source[name_span]),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::MissingAttribute(name, name_span) => ParseError {
+ message: format!(
+ "variable `{}` needs a '{}' attribute",
+ &source[name_span], name
+ ),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::InvalidAtomicPointer(span) => ParseError {
+ message: "atomic operation is done on a pointer to a non-atomic".to_string(),
+ labels: vec![(span, "atomic pointer is invalid".into())],
+ notes: vec![],
+ },
+ Error::InvalidAtomicOperandType(span) => ParseError {
+ message: "atomic operand type is inconsistent with the operation".to_string(),
+ labels: vec![(span, "atomic operand type is invalid".into())],
+ notes: vec![],
+ },
+ Error::InvalidRayQueryPointer(span) => ParseError {
+ message: "ray query operation is done on a pointer to a non-ray-query".to_string(),
+ labels: vec![(span, "ray query pointer is invalid".into())],
+ notes: vec![],
+ },
+ Error::NotPointer(span) => ParseError {
+ message: "the operand of the `*` operator must be a pointer".to_string(),
+ labels: vec![(span, "expression is not a pointer".into())],
+ notes: vec![],
+ },
+ Error::NotReference(what, span) => ParseError {
+ message: format!("{what} must be a reference"),
+ labels: vec![(span, "expression is not a reference".into())],
+ notes: vec![],
+ },
+ Error::InvalidAssignment { span, ty } => {
+ let (extra_label, notes) = match ty {
+ InvalidAssignmentType::Swizzle => (
+ None,
+ vec![
+ "WGSL does not support assignments to swizzles".into(),
+ "consider assigning each component individually".into(),
+ ],
+ ),
+ InvalidAssignmentType::ImmutableBinding(binding_span) => (
+ Some((binding_span, "this is an immutable binding".into())),
+ vec![format!(
+ "consider declaring '{}' with `var` instead of `let`",
+ &source[binding_span]
+ )],
+ ),
+ InvalidAssignmentType::Other => (None, vec![]),
+ };
+
+ ParseError {
+ message: "invalid left-hand side of assignment".into(),
+ labels: std::iter::once((span, "cannot assign to this expression".into()))
+ .chain(extra_label)
+ .collect(),
+ notes,
+ }
+ }
+ Error::Pointer(what, span) => ParseError {
+ message: format!("{what} must not be a pointer"),
+ labels: vec![(span, "expression is a pointer".into())],
+ notes: vec![],
+ },
+ Error::ReservedKeyword(name_span) => ParseError {
+ message: format!("name `{}` is a reserved keyword", &source[name_span]),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::Redefinition { previous, current } => ParseError {
+ message: format!("redefinition of `{}`", &source[current]),
+ labels: vec![
+ (
+ current,
+ format!("redefinition of `{}`", &source[current]).into(),
+ ),
+ (
+ previous,
+ format!("previous definition of `{}`", &source[previous]).into(),
+ ),
+ ],
+ notes: vec![],
+ },
+ Error::RecursiveDeclaration { ident, usage } => ParseError {
+ message: format!("declaration of `{}` is recursive", &source[ident]),
+ labels: vec![(ident, "".into()), (usage, "uses itself here".into())],
+ notes: vec![],
+ },
+ Error::CyclicDeclaration { ident, ref path } => ParseError {
+ message: format!("declaration of `{}` is cyclic", &source[ident]),
+ labels: path
+ .iter()
+ .enumerate()
+ .flat_map(|(i, &(ident, usage))| {
+ [
+ (ident, "".into()),
+ (
+ usage,
+ if i == path.len() - 1 {
+ "ending the cycle".into()
+ } else {
+ format!("uses `{}`", &source[ident]).into()
+ },
+ ),
+ ]
+ })
+ .collect(),
+ notes: vec![],
+ },
+ Error::ConstExprUnsupported(span) => ParseError {
+ message: "this constant expression is not supported".to_string(),
+ labels: vec![(span, "expression is not supported".into())],
+ notes: vec![
+ "this should be fixed in a future version of Naga".into(),
+ "https://github.com/gfx-rs/naga/issues/1829".into(),
+ ],
+ },
+ Error::InvalidSwitchValue { uint, span } => ParseError {
+ message: "invalid switch value".to_string(),
+ labels: vec![(
+ span,
+ if uint {
+ "expected unsigned integer"
+ } else {
+ "expected signed integer"
+ }
+ .into(),
+ )],
+ notes: vec![if uint {
+ format!("suffix the integer with a `u`: '{}u'", &source[span])
+ } else {
+ let span = span.to_range().unwrap();
+ format!(
+ "remove the `u` suffix: '{}'",
+ &source[span.start..span.end - 1]
+ )
+ }],
+ },
+ Error::CalledEntryPoint(span) => ParseError {
+ message: "entry point cannot be called".to_string(),
+ labels: vec![(span, "entry point cannot be called".into())],
+ notes: vec![],
+ },
+ Error::WrongArgumentCount {
+ span,
+ ref expected,
+ found,
+ } => ParseError {
+ message: format!(
+ "wrong number of arguments: expected {}, found {}",
+ if expected.len() < 2 {
+ format!("{}", expected.start)
+ } else {
+ format!("{}..{}", expected.start, expected.end)
+ },
+ found
+ ),
+ labels: vec![(span, "wrong number of arguments".into())],
+ notes: vec![],
+ },
+ Error::FunctionReturnsVoid(span) => ParseError {
+ message: "function does not return any value".to_string(),
+ labels: vec![(span, "".into())],
+ notes: vec![
+ "perhaps you meant to call the function in a separate statement?".into(),
+ ],
+ },
+ Error::Other => ParseError {
+ message: "other error".to_string(),
+ labels: vec![],
+ notes: vec![],
+ },
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs
new file mode 100644
index 0000000000..a5524fe8f1
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/index.rs
@@ -0,0 +1,193 @@
+use super::Error;
+use crate::front::wgsl::parse::ast;
+use crate::{FastHashMap, Handle, Span};
+
+/// A `GlobalDecl` list in which each definition occurs before all its uses.
+pub struct Index<'a> {
+ dependency_order: Vec<Handle<ast::GlobalDecl<'a>>>,
+}
+
+impl<'a> Index<'a> {
+ /// Generate an `Index` for the given translation unit.
+ ///
+ /// Perform a topological sort on `tu`'s global declarations, placing
+ /// referents before the definitions that refer to them.
+ ///
+ /// Return an error if the graph of references between declarations contains
+ /// any cycles.
+ pub fn generate(tu: &ast::TranslationUnit<'a>) -> Result<Self, Error<'a>> {
+ // Produce a map from global definitions' names to their `Handle<GlobalDecl>`s.
+ // While doing so, reject conflicting definitions.
+ let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default());
+ for (handle, decl) in tu.decls.iter() {
+ let ident = decl_ident(decl);
+ let name = ident.name;
+ if let Some(old) = globals.insert(name, handle) {
+ return Err(Error::Redefinition {
+ previous: decl_ident(&tu.decls[old]).span,
+ current: ident.span,
+ });
+ }
+ }
+
+ let len = tu.decls.len();
+ let solver = DependencySolver {
+ globals: &globals,
+ module: tu,
+ visited: vec![false; len],
+ temp_visited: vec![false; len],
+ path: Vec::new(),
+ out: Vec::with_capacity(len),
+ };
+ let dependency_order = solver.solve()?;
+
+ Ok(Self { dependency_order })
+ }
+
+ /// Iterate over `GlobalDecl`s, visiting each definition before all its uses.
+ ///
+ /// Produce handles for all of the `GlobalDecl`s of the `TranslationUnit`
+ /// passed to `Index::generate`, ordered so that a given declaration is
+ /// produced before any other declaration that uses it.
+ pub fn visit_ordered(&self) -> impl Iterator<Item = Handle<ast::GlobalDecl<'a>>> + '_ {
+ self.dependency_order.iter().copied()
+ }
+}
+
+/// An edge from a reference to its referent in the current depth-first
+/// traversal.
+///
+/// This is like `ast::Dependency`, except that we've determined which
+/// `GlobalDecl` it refers to.
+struct ResolvedDependency<'a> {
+ /// The referent of some identifier used in the current declaration.
+ decl: Handle<ast::GlobalDecl<'a>>,
+
+ /// Where that use occurs within the current declaration.
+ usage: Span,
+}
+
+/// Local state for ordering a `TranslationUnit`'s module-scope declarations.
+///
+/// Values of this type are used temporarily by `Index::generate`
+/// to perform a depth-first sort on the declarations.
+/// Technically, what we want is a topological sort, but a depth-first sort
+/// has one key benefit - it's much more efficient in storing
+/// the path of each node for error generation.
+struct DependencySolver<'source, 'temp> {
+ /// A map from module-scope definitions' names to their handles.
+ globals: &'temp FastHashMap<&'source str, Handle<ast::GlobalDecl<'source>>>,
+
+ /// The translation unit whose declarations we're ordering.
+ module: &'temp ast::TranslationUnit<'source>,
+
+ /// For each handle, whether we have pushed it onto `out` yet.
+ visited: Vec<bool>,
+
+ /// For each handle, whether it is an predecessor in the current depth-first
+ /// traversal. This is used to detect cycles in the reference graph.
+ temp_visited: Vec<bool>,
+
+ /// The current path in our depth-first traversal. Used for generating
+ /// error messages for non-trivial reference cycles.
+ path: Vec<ResolvedDependency<'source>>,
+
+ /// The list of declaration handles, with declarations before uses.
+ out: Vec<Handle<ast::GlobalDecl<'source>>>,
+}
+
+impl<'a> DependencySolver<'a, '_> {
+ /// Produce the sorted list of declaration handles, and check for cycles.
+ fn solve(mut self) -> Result<Vec<Handle<ast::GlobalDecl<'a>>>, Error<'a>> {
+ for (id, _) in self.module.decls.iter() {
+ if self.visited[id.index()] {
+ continue;
+ }
+
+ self.dfs(id)?;
+ }
+
+ Ok(self.out)
+ }
+
+ /// Ensure that all declarations used by `id` have been added to the
+ /// ordering, and then append `id` itself.
+ fn dfs(&mut self, id: Handle<ast::GlobalDecl<'a>>) -> Result<(), Error<'a>> {
+ let decl = &self.module.decls[id];
+ let id_usize = id.index();
+
+ self.temp_visited[id_usize] = true;
+ for dep in decl.dependencies.iter() {
+ if let Some(&dep_id) = self.globals.get(dep.ident) {
+ self.path.push(ResolvedDependency {
+ decl: dep_id,
+ usage: dep.usage,
+ });
+ let dep_id_usize = dep_id.index();
+
+ if self.temp_visited[dep_id_usize] {
+ // Found a cycle.
+ return if dep_id == id {
+ // A declaration refers to itself directly.
+ Err(Error::RecursiveDeclaration {
+ ident: decl_ident(decl).span,
+ usage: dep.usage,
+ })
+ } else {
+ // A declaration refers to itself indirectly, through
+ // one or more other definitions. Report the entire path
+ // of references.
+ let start_at = self
+ .path
+ .iter()
+ .rev()
+ .enumerate()
+ .find_map(|(i, dep)| (dep.decl == dep_id).then_some(i))
+ .unwrap_or(0);
+
+ Err(Error::CyclicDeclaration {
+ ident: decl_ident(&self.module.decls[dep_id]).span,
+ path: self.path[start_at..]
+ .iter()
+ .map(|curr_dep| {
+ let curr_id = curr_dep.decl;
+ let curr_decl = &self.module.decls[curr_id];
+
+ (decl_ident(curr_decl).span, curr_dep.usage)
+ })
+ .collect(),
+ })
+ };
+ } else if !self.visited[dep_id_usize] {
+ self.dfs(dep_id)?;
+ }
+
+ // Remove this edge from the current path.
+ self.path.pop();
+ }
+
+ // Ignore unresolved identifiers; they may be predeclared objects.
+ }
+
+ // Remove this node from the current path.
+ self.temp_visited[id_usize] = false;
+
+ // Now everything this declaration uses has been visited, and is already
+ // present in `out`. That means we we can append this one to the
+ // ordering, and mark it as visited.
+ self.out.push(id);
+ self.visited[id_usize] = true;
+
+ Ok(())
+ }
+}
+
+const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
+ match decl.kind {
+ ast::GlobalDeclKind::Fn(ref f) => f.name,
+ ast::GlobalDeclKind::Var(ref v) => v.name,
+ ast::GlobalDeclKind::Const(ref c) => c.name,
+ ast::GlobalDeclKind::Struct(ref s) => s.name,
+ ast::GlobalDeclKind::Type(ref t) => t.name,
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/construction.rs b/third_party/rust/naga/src/front/wgsl/lower/construction.rs
new file mode 100644
index 0000000000..723d4441f5
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/construction.rs
@@ -0,0 +1,668 @@
+use crate::front::wgsl::parse::ast;
+use crate::{Handle, Span};
+
+use crate::front::wgsl::error::Error;
+use crate::front::wgsl::lower::{ExpressionContext, Lowerer, OutputContext};
+use crate::proc::TypeResolution;
+
+enum ConcreteConstructorHandle {
+ PartialVector {
+ size: crate::VectorSize,
+ },
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+ PartialArray,
+ Type(Handle<crate::Type>),
+}
+
+impl ConcreteConstructorHandle {
+ fn borrow<'a>(&self, module: &'a crate::Module) -> ConcreteConstructor<'a> {
+ match *self {
+ Self::PartialVector { size } => ConcreteConstructor::PartialVector { size },
+ Self::PartialMatrix { columns, rows } => {
+ ConcreteConstructor::PartialMatrix { columns, rows }
+ }
+ Self::PartialArray => ConcreteConstructor::PartialArray,
+ Self::Type(handle) => ConcreteConstructor::Type(handle, &module.types[handle].inner),
+ }
+ }
+}
+
+enum ConcreteConstructor<'a> {
+ PartialVector {
+ size: crate::VectorSize,
+ },
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+ PartialArray,
+ Type(Handle<crate::Type>, &'a crate::TypeInner),
+}
+
+impl ConcreteConstructorHandle {
+ fn to_error_string(&self, ctx: ExpressionContext) -> String {
+ match *self {
+ Self::PartialVector { size } => {
+ format!("vec{}<?>", size as u32,)
+ }
+ Self::PartialMatrix { columns, rows } => {
+ format!("mat{}x{}<?>", columns as u32, rows as u32,)
+ }
+ Self::PartialArray => "array<?, ?>".to_string(),
+ Self::Type(ty) => ctx.format_type(ty),
+ }
+ }
+}
+
+enum ComponentsHandle<'a> {
+ None,
+ One {
+ component: Handle<crate::Expression>,
+ span: Span,
+ ty: &'a TypeResolution,
+ },
+ Many {
+ components: Vec<Handle<crate::Expression>>,
+ spans: Vec<Span>,
+ first_component_ty: &'a TypeResolution,
+ },
+}
+
+impl<'a> ComponentsHandle<'a> {
+ fn borrow(self, module: &'a crate::Module) -> Components<'a> {
+ match self {
+ Self::None => Components::None,
+ Self::One {
+ component,
+ span,
+ ty,
+ } => Components::One {
+ component,
+ span,
+ ty_inner: ty.inner_with(&module.types),
+ },
+ Self::Many {
+ components,
+ spans,
+ first_component_ty,
+ } => Components::Many {
+ components,
+ spans,
+ first_component_ty_inner: first_component_ty.inner_with(&module.types),
+ },
+ }
+ }
+}
+
+enum Components<'a> {
+ None,
+ One {
+ component: Handle<crate::Expression>,
+ span: Span,
+ ty_inner: &'a crate::TypeInner,
+ },
+ Many {
+ components: Vec<Handle<crate::Expression>>,
+ spans: Vec<Span>,
+ first_component_ty_inner: &'a crate::TypeInner,
+ },
+}
+
+impl Components<'_> {
+ fn into_components_vec(self) -> Vec<Handle<crate::Expression>> {
+ match self {
+ Self::None => vec![],
+ Self::One { component, .. } => vec![component],
+ Self::Many { components, .. } => components,
+ }
+ }
+}
+
+impl<'source, 'temp> Lowerer<'source, 'temp> {
+ /// Generate Naga IR for a type constructor expression.
+ ///
+ /// The `constructor` value represents the head of the constructor
+ /// expression, which is at least a hint of which type is being built; if
+ /// it's one of the `Partial` variants, we need to consider the argument
+ /// types as well.
+ ///
+ /// This is used for [`Construct`] expressions, but also for [`Call`]
+ /// expressions, once we've determined that the "callable" (in WGSL spec
+ /// terms) is actually a type.
+ ///
+ /// [`Construct`]: ast::Expression::Construct
+ /// [`Call`]: ast::Expression::Call
+ pub fn construct(
+ &mut self,
+ span: Span,
+ constructor: &ast::ConstructorType<'source>,
+ ty_span: Span,
+ components: &[Handle<ast::Expression<'source>>],
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let constructor_h = self.constructor(constructor, ctx.as_output())?;
+
+ let components_h = match *components {
+ [] => ComponentsHandle::None,
+ [component] => {
+ let span = ctx.ast_expressions.get_span(component);
+ let component = self.expression(component, ctx.reborrow())?;
+ ctx.grow_types(component)?;
+ let ty = &ctx.typifier[component];
+
+ ComponentsHandle::One {
+ component,
+ span,
+ ty,
+ }
+ }
+ [component, ref rest @ ..] => {
+ let span = ctx.ast_expressions.get_span(component);
+ let component = self.expression(component, ctx.reborrow())?;
+
+ let components = std::iter::once(Ok(component))
+ .chain(
+ rest.iter()
+ .map(|&component| self.expression(component, ctx.reborrow())),
+ )
+ .collect::<Result<_, _>>()?;
+ let spans = std::iter::once(span)
+ .chain(
+ rest.iter()
+ .map(|&component| ctx.ast_expressions.get_span(component)),
+ )
+ .collect();
+
+ ctx.grow_types(component)?;
+ let ty = &ctx.typifier[component];
+
+ ComponentsHandle::Many {
+ components,
+ spans,
+ first_component_ty: ty,
+ }
+ }
+ };
+
+ let (components, constructor) = (
+ components_h.borrow(ctx.module),
+ constructor_h.borrow(ctx.module),
+ );
+ let expr = match (components, constructor) {
+ // Empty constructor
+ (Components::None, dst_ty) => {
+ let ty = match dst_ty {
+ ConcreteConstructor::Type(ty, _) => ty,
+ _ => return Err(Error::TypeNotInferrable(ty_span)),
+ };
+
+ return match ctx.create_zero_value_constant(ty) {
+ Some(constant) => {
+ Ok(ctx.interrupt_emitter(crate::Expression::Constant(constant), span))
+ }
+ None => Err(Error::TypeNotConstructible(ty_span)),
+ };
+ }
+
+ // Scalar constructor & conversion (scalar -> scalar)
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { kind, width }),
+ ) => crate::Expression::As {
+ expr: component,
+ kind,
+ convert: Some(width),
+ },
+
+ // Vector conversion (vector -> vector)
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Vector {
+ size: dst_size,
+ kind: dst_kind,
+ width: dst_width,
+ },
+ ),
+ ) if dst_size == src_size => crate::Expression::As {
+ expr: component,
+ kind: dst_kind,
+ convert: Some(dst_width),
+ },
+
+ // Vector conversion (vector -> vector) - partial
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Vector {
+ size: src_size,
+ kind: src_kind,
+ ..
+ },
+ ..
+ },
+ ConcreteConstructor::PartialVector { size: dst_size },
+ ) if dst_size == src_size => crate::Expression::As {
+ expr: component,
+ kind: src_kind,
+ convert: None,
+ },
+
+ // Matrix conversion (matrix -> matrix)
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Matrix {
+ columns: src_columns,
+ rows: src_rows,
+ ..
+ },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Matrix {
+ columns: dst_columns,
+ rows: dst_rows,
+ width: dst_width,
+ },
+ ),
+ ) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As {
+ expr: component,
+ kind: crate::ScalarKind::Float,
+ convert: Some(dst_width),
+ },
+
+ // Matrix conversion (matrix -> matrix) - partial
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Matrix {
+ columns: src_columns,
+ rows: src_rows,
+ ..
+ },
+ ..
+ },
+ ConcreteConstructor::PartialMatrix {
+ columns: dst_columns,
+ rows: dst_rows,
+ },
+ ) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As {
+ expr: component,
+ kind: crate::ScalarKind::Float,
+ convert: None,
+ },
+
+ // Vector constructor (splat) - infer type
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ ConcreteConstructor::PartialVector { size },
+ ) => crate::Expression::Splat {
+ size,
+ value: component,
+ },
+
+ // Vector constructor (splat)
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Scalar {
+ kind: src_kind,
+ width: src_width,
+ ..
+ },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Vector {
+ size,
+ kind: dst_kind,
+ width: dst_width,
+ },
+ ),
+ ) if dst_kind == src_kind || dst_width == src_width => crate::Expression::Splat {
+ size,
+ value: component,
+ },
+
+ // Vector constructor (by elements)
+ (
+ Components::Many {
+ components,
+ first_component_ty_inner:
+ &crate::TypeInner::Scalar { kind, width }
+ | &crate::TypeInner::Vector { kind, width, .. },
+ ..
+ },
+ ConcreteConstructor::PartialVector { size },
+ )
+ | (
+ Components::Many {
+ components,
+ first_component_ty_inner:
+ &crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. },
+ ..
+ },
+ ConcreteConstructor::Type(_, &crate::TypeInner::Vector { size, width, kind }),
+ ) => {
+ let inner = crate::TypeInner::Vector { size, kind, width };
+ let ty = ctx.ensure_type_exists(inner);
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Matrix constructor (by elements)
+ (
+ Components::Many {
+ components,
+ first_component_ty_inner: &crate::TypeInner::Scalar { width, .. },
+ ..
+ },
+ ConcreteConstructor::PartialMatrix { columns, rows },
+ )
+ | (
+ Components::Many {
+ components,
+ first_component_ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ ),
+ ) => {
+ let vec_ty = ctx.ensure_type_exists(crate::TypeInner::Vector {
+ width,
+ kind: crate::ScalarKind::Float,
+ size: rows,
+ });
+
+ let components = components
+ .chunks(rows as usize)
+ .map(|vec_components| {
+ ctx.naga_expressions.append(
+ crate::Expression::Compose {
+ ty: vec_ty,
+ components: Vec::from(vec_components),
+ },
+ Default::default(),
+ )
+ })
+ .collect();
+
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ });
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Matrix constructor (by columns)
+ (
+ Components::Many {
+ components,
+ first_component_ty_inner: &crate::TypeInner::Vector { width, .. },
+ ..
+ },
+ ConcreteConstructor::PartialMatrix { columns, rows },
+ )
+ | (
+ Components::Many {
+ components,
+ first_component_ty_inner: &crate::TypeInner::Vector { .. },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ ),
+ ) => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ });
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Array constructor - infer type
+ (components, ConcreteConstructor::PartialArray) => {
+ let components = components.into_components_vec();
+
+ let base = ctx.register_type(components[0])?;
+
+ let size = crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(components.len() as _),
+ },
+ };
+
+ let inner = crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(
+ ctx.module.constants.fetch_or_append(size, Span::UNDEFINED),
+ ),
+ stride: {
+ self.layouter
+ .update(&ctx.module.types, &ctx.module.constants)
+ .unwrap();
+ self.layouter[base].to_stride()
+ },
+ };
+ let ty = ctx.ensure_type_exists(inner);
+
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Array constructor
+ (components, ConcreteConstructor::Type(ty, &crate::TypeInner::Array { .. })) => {
+ let components = components.into_components_vec();
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Struct constructor
+ (components, ConcreteConstructor::Type(ty, &crate::TypeInner::Struct { .. })) => {
+ crate::Expression::Compose {
+ ty,
+ components: components.into_components_vec(),
+ }
+ }
+
+ // ERRORS
+
+ // Bad conversion (type cast)
+ (Components::One { span, ty_inner, .. }, _) => {
+ let from_type = ctx.format_typeinner(ty_inner);
+ return Err(Error::BadTypeCast {
+ span,
+ from_type,
+ to_type: constructor_h.to_error_string(ctx.reborrow()),
+ });
+ }
+
+ // Too many parameters for scalar constructor
+ (
+ Components::Many { spans, .. },
+ ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { .. }),
+ ) => {
+ let span = spans[1].until(spans.last().unwrap());
+ return Err(Error::UnexpectedComponents(span));
+ }
+
+ // Parameters are of the wrong type for vector or matrix constructor
+ (
+ Components::Many { spans, .. },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. },
+ )
+ | ConcreteConstructor::PartialVector { .. }
+ | ConcreteConstructor::PartialMatrix { .. },
+ ) => {
+ return Err(Error::InvalidConstructorComponentType(spans[0], 0));
+ }
+
+ // Other types can't be constructed
+ _ => return Err(Error::TypeNotConstructible(ty_span)),
+ };
+
+ let expr = ctx.naga_expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ /// Build a Naga IR [`ConstantInner`] given a WGSL construction expression.
+ ///
+ /// Given `constructor`, representing the head of a WGSL [`type constructor
+ /// expression`], and a slice of [`ast::Expression`] handles representing
+ /// the constructor's arguments, build a Naga [`ConstantInner`] value
+ /// representing the given value.
+ ///
+ /// If `constructor` is for a composite type, this may entail adding new
+ /// [`Type`]s and [`Constant`]s to [`ctx.module`], if it doesn't already
+ /// have what we need.
+ ///
+ /// If the arguments cannot be evaluated at compile time, return an error.
+ ///
+ /// [`ConstantInner`]: crate::ConstantInner
+ /// [`type constructor expression`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr
+ /// [`Function::expressions`]: ast::Function::expressions
+ /// [`TranslationUnit::global_expressions`]: ast::TranslationUnit::global_expressions
+ /// [`Type`]: crate::Type
+ /// [`Constant`]: crate::Constant
+ /// [`ctx.module`]: OutputContext::module
+ pub fn const_construct(
+ &mut self,
+ span: Span,
+ constructor: &ast::ConstructorType<'source>,
+ components: &[Handle<ast::Expression<'source>>],
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<crate::ConstantInner, Error<'source>> {
+ // TODO: Support zero values, splatting and inference.
+
+ let constructor = self.constructor(constructor, ctx.reborrow())?;
+
+ let c = match constructor {
+ ConcreteConstructorHandle::Type(ty) => {
+ let components = components
+ .iter()
+ .map(|&expr| self.constant(expr, ctx.reborrow()))
+ .collect::<Result<_, _>>()?;
+
+ crate::ConstantInner::Composite { ty, components }
+ }
+ _ => return Err(Error::ConstExprUnsupported(span)),
+ };
+ Ok(c)
+ }
+
+ /// Build a Naga IR [`Type`] for `constructor` if there is enough
+ /// information to do so.
+ ///
+ /// For `Partial` variants of [`ast::ConstructorType`], we don't know the
+ /// component type, so in that case we return the appropriate `Partial`
+ /// variant of [`ConcreteConstructorHandle`].
+ ///
+ /// But for the other `ConstructorType` variants, we have everything we need
+ /// to know to actually produce a Naga IR type. In this case we add to/find
+ /// in [`ctx.module`] a suitable Naga `Type` and return a
+ /// [`ConcreteConstructorHandle::Type`] value holding its handle.
+ ///
+ /// Note that constructing an [`Array`] type may require inserting
+ /// [`Constant`]s as well as `Type`s into `ctx.module`, to represent the
+ /// array's length.
+ ///
+ /// [`Type`]: crate::Type
+ /// [`ctx.module`]: OutputContext::module
+ /// [`Array`]: crate::TypeInner::Array
+ /// [`Constant`]: crate::Constant
+ fn constructor<'out>(
+ &mut self,
+ constructor: &ast::ConstructorType<'source>,
+ mut ctx: OutputContext<'source, '_, 'out>,
+ ) -> Result<ConcreteConstructorHandle, Error<'source>> {
+ let c = match *constructor {
+ ast::ConstructorType::Scalar { width, kind } => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Scalar { width, kind });
+ ConcreteConstructorHandle::Type(ty)
+ }
+ ast::ConstructorType::PartialVector { size } => {
+ ConcreteConstructorHandle::PartialVector { size }
+ }
+ ast::ConstructorType::Vector { size, kind, width } => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, kind, width });
+ ConcreteConstructorHandle::Type(ty)
+ }
+ ast::ConstructorType::PartialMatrix { rows, columns } => {
+ ConcreteConstructorHandle::PartialMatrix { rows, columns }
+ }
+ ast::ConstructorType::Matrix {
+ rows,
+ columns,
+ width,
+ } => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ });
+ ConcreteConstructorHandle::Type(ty)
+ }
+ ast::ConstructorType::PartialArray => ConcreteConstructorHandle::PartialArray,
+ ast::ConstructorType::Array { base, size } => {
+ let base = self.resolve_ast_type(base, ctx.reborrow())?;
+ let size = match size {
+ ast::ArraySize::Constant(expr) => {
+ crate::ArraySize::Constant(self.constant(expr, ctx.reborrow())?)
+ }
+ ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
+ };
+
+ self.layouter
+ .update(&ctx.module.types, &ctx.module.constants)
+ .unwrap();
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Array {
+ base,
+ size,
+ stride: self.layouter[base].to_stride(),
+ });
+ ConcreteConstructorHandle::Type(ty)
+ }
+ ast::ConstructorType::Type(ty) => ConcreteConstructorHandle::Type(ty),
+ };
+
+ Ok(c)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
new file mode 100644
index 0000000000..bc9cce1bee
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
@@ -0,0 +1,2426 @@
+use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType};
+use crate::front::wgsl::index::Index;
+use crate::front::wgsl::parse::number::Number;
+use crate::front::wgsl::parse::{ast, conv};
+use crate::front::{Emitter, Typifier};
+use crate::proc::{ensure_block_returns, Alignment, Layouter, ResolveContext, TypeResolution};
+use crate::{Arena, FastHashMap, Handle, Span};
+use indexmap::IndexMap;
+
+mod construction;
+
+/// State for constructing a `crate::Module`.
+pub struct OutputContext<'source, 'temp, 'out> {
+ /// The `TranslationUnit`'s expressions arena.
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+
+ /// The `TranslationUnit`'s types arena.
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// The module we're constructing.
+ module: &'out mut crate::Module,
+}
+
+impl<'source> OutputContext<'source, '_, '_> {
+ fn reborrow(&mut self) -> OutputContext<'source, '_, '_> {
+ OutputContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ }
+ }
+
+ fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle<crate::Type> {
+ self.module
+ .types
+ .insert(crate::Type { inner, name: None }, Span::UNDEFINED)
+ }
+}
+
+/// State for lowering a statement within a function.
+pub struct StatementContext<'source, 'temp, 'out> {
+ // WGSL AST values.
+ /// A reference to [`TranslationUnit::expressions`] for the translation unit
+ /// we're lowering.
+ ///
+ /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+
+ /// A reference to [`TranslationUnit::types`] for the translation unit
+ /// we're lowering.
+ ///
+ /// [`TranslationUnit::types`]: ast::TranslationUnit::types
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// A map from `ast::Local` handles to the Naga expressions we've built for them.
+ ///
+ /// The Naga expressions are either [`LocalVariable`] or
+ /// [`FunctionArgument`] expressions.
+ ///
+ /// [`LocalVariable`]: crate::Expression::LocalVariable
+ /// [`FunctionArgument`]: crate::Expression::FunctionArgument
+ local_table: &'temp mut FastHashMap<Handle<ast::Local>, TypedExpression>,
+
+ typifier: &'temp mut Typifier,
+ variables: &'out mut Arena<crate::LocalVariable>,
+ naga_expressions: &'out mut Arena<crate::Expression>,
+ /// Stores the names of expressions that are assigned in `let` statement
+ /// Also stores the spans of the names, for use in errors.
+ named_expressions: &'out mut IndexMap<Handle<crate::Expression>, (String, Span)>,
+ arguments: &'out [crate::FunctionArgument],
+ module: &'out mut crate::Module,
+}
+
+impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
+ fn reborrow(&mut self) -> StatementContext<'a, '_, '_> {
+ StatementContext {
+ local_table: self.local_table,
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ typifier: self.typifier,
+ variables: self.variables,
+ naga_expressions: self.naga_expressions,
+ named_expressions: self.named_expressions,
+ arguments: self.arguments,
+ module: self.module,
+ }
+ }
+
+ fn as_expression<'t>(
+ &'t mut self,
+ block: &'t mut crate::Block,
+ emitter: &'t mut Emitter,
+ ) -> ExpressionContext<'a, 't, '_>
+ where
+ 'temp: 't,
+ {
+ ExpressionContext {
+ local_table: self.local_table,
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ typifier: self.typifier,
+ naga_expressions: self.naga_expressions,
+ module: self.module,
+ local_vars: self.variables,
+ arguments: self.arguments,
+ block,
+ emitter,
+ }
+ }
+
+ fn as_output(&mut self) -> OutputContext<'a, '_, '_> {
+ OutputContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ }
+ }
+
+ fn invalid_assignment_type(&self, expr: Handle<crate::Expression>) -> InvalidAssignmentType {
+ if let Some(&(_, span)) = self.named_expressions.get(&expr) {
+ InvalidAssignmentType::ImmutableBinding(span)
+ } else {
+ match self.naga_expressions[expr] {
+ crate::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle,
+ crate::Expression::Access { base, .. } => self.invalid_assignment_type(base),
+ crate::Expression::AccessIndex { base, .. } => self.invalid_assignment_type(base),
+ _ => InvalidAssignmentType::Other,
+ }
+ }
+ }
+}
+
+/// State for lowering an `ast::Expression` to Naga IR.
+///
+/// Not to be confused with `parser::ExpressionContext`.
+pub struct ExpressionContext<'source, 'temp, 'out> {
+ // WGSL AST values.
+ local_table: &'temp mut FastHashMap<Handle<ast::Local>, TypedExpression>,
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ typifier: &'temp mut Typifier,
+ naga_expressions: &'out mut Arena<crate::Expression>,
+ local_vars: &'out Arena<crate::LocalVariable>,
+ arguments: &'out [crate::FunctionArgument],
+ module: &'out mut crate::Module,
+ block: &'temp mut crate::Block,
+ emitter: &'temp mut Emitter,
+}
+
+impl<'a> ExpressionContext<'a, '_, '_> {
+ fn reborrow(&mut self) -> ExpressionContext<'a, '_, '_> {
+ ExpressionContext {
+ local_table: self.local_table,
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ typifier: self.typifier,
+ naga_expressions: self.naga_expressions,
+ module: self.module,
+ local_vars: self.local_vars,
+ arguments: self.arguments,
+ block: self.block,
+ emitter: self.emitter,
+ }
+ }
+
+ fn as_output(&mut self) -> OutputContext<'a, '_, '_> {
+ OutputContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ }
+ }
+
+ /// Determine the type of `handle`, and add it to the module's arena.
+ ///
+ /// If you just need a `TypeInner` for `handle`'s type, use
+ /// [`grow_types`] and [`resolved_inner`] instead. This function
+ /// should only be used when the type of `handle` needs to appear
+ /// in the module's final `Arena<Type>`, for example, if you're
+ /// creating a [`LocalVariable`] whose type is inferred from its
+ /// initializer.
+ ///
+ /// [`grow_types`]: Self::grow_types
+ /// [`resolved_inner`]: Self::resolved_inner
+ /// [`LocalVariable`]: crate::LocalVariable
+ fn register_type(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, Error<'a>> {
+ self.grow_types(handle)?;
+ Ok(self.typifier.register_type(handle, &mut self.module.types))
+ }
+
+ /// Resolve the types of all expressions up through `handle`.
+ ///
+ /// Ensure that [`self.typifier`] has a [`TypeResolution`] for
+ /// every expression in [`self.naga_expressions`].
+ ///
+ /// This does not add types to any arena. The [`Typifier`]
+ /// documentation explains the steps we take to avoid filling
+ /// arenas with intermediate types.
+ ///
+ /// This function takes `&mut self`, so it can't conveniently
+ /// return a shared reference to the resulting `TypeResolution`:
+ /// the shared reference would extend the mutable borrow, and you
+ /// wouldn't be able to use `self` for anything else. Instead, you
+ /// should call `grow_types` to cover the handles you need, and
+ /// then use `self.typifier[handle]` or
+ /// [`self.resolved_inner(handle)`] to get at their resolutions.
+ ///
+ /// [`self.typifier`]: ExpressionContext::typifier
+ /// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner
+ /// [`Typifier`]: Typifier
+ fn grow_types(&mut self, handle: Handle<crate::Expression>) -> Result<&mut Self, Error<'a>> {
+ let resolve_ctx = ResolveContext::with_locals(self.module, self.local_vars, self.arguments);
+ self.typifier
+ .grow(handle, self.naga_expressions, &resolve_ctx)
+ .map_err(Error::InvalidResolve)?;
+ Ok(self)
+ }
+
+ fn resolved_inner(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
+ self.typifier[handle].inner_with(&self.module.types)
+ }
+
+ fn image_data(
+ &mut self,
+ image: Handle<crate::Expression>,
+ span: Span,
+ ) -> Result<(crate::ImageClass, bool), Error<'a>> {
+ self.grow_types(image)?;
+ match *self.resolved_inner(image) {
+ crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)),
+ _ => Err(Error::BadTexture(span)),
+ }
+ }
+
+ fn prepare_args<'b>(
+ &mut self,
+ args: &'b [Handle<ast::Expression<'a>>],
+ min_args: u32,
+ span: Span,
+ ) -> ArgumentContext<'b, 'a> {
+ ArgumentContext {
+ args: args.iter(),
+ min_args,
+ args_used: 0,
+ total_args: args.len() as u32,
+ span,
+ }
+ }
+
+ /// Insert splats, if needed by the non-'*' operations.
+ fn binary_op_splat(
+ &mut self,
+ op: crate::BinaryOperator,
+ left: &mut Handle<crate::Expression>,
+ right: &mut Handle<crate::Expression>,
+ ) -> Result<(), Error<'a>> {
+ if op != crate::BinaryOperator::Multiply {
+ self.grow_types(*left)?.grow_types(*right)?;
+
+ let left_size = match *self.resolved_inner(*left) {
+ crate::TypeInner::Vector { size, .. } => Some(size),
+ _ => None,
+ };
+
+ match (left_size, self.resolved_inner(*right)) {
+ (Some(size), &crate::TypeInner::Scalar { .. }) => {
+ *right = self.naga_expressions.append(
+ crate::Expression::Splat {
+ size,
+ value: *right,
+ },
+ self.naga_expressions.get_span(*right),
+ );
+ }
+ (None, &crate::TypeInner::Vector { size, .. }) => {
+ *left = self.naga_expressions.append(
+ crate::Expression::Splat { size, value: *left },
+ self.naga_expressions.get_span(*left),
+ );
+ }
+ _ => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Add a single expression to the expression table that is not covered by `self.emitter`.
+ ///
+ /// This is useful for `CallResult` and `AtomicResult` expressions, which should not be covered by
+ /// `Emit` statements.
+ fn interrupt_emitter(
+ &mut self,
+ expression: crate::Expression,
+ span: Span,
+ ) -> Handle<crate::Expression> {
+ self.block
+ .extend(self.emitter.finish(self.naga_expressions));
+ let result = self.naga_expressions.append(expression, span);
+ self.emitter.start(self.naga_expressions);
+ result
+ }
+
+ /// Apply the WGSL Load Rule to `expr`.
+ ///
+ /// If `expr` is has type `ref<SC, T, A>`, perform a load to produce a value of type
+ /// `T`. Otherwise, return `expr` unchanged.
+ fn apply_load_rule(&mut self, expr: TypedExpression) -> Handle<crate::Expression> {
+ if expr.is_reference {
+ let load = crate::Expression::Load {
+ pointer: expr.handle,
+ };
+ let span = self.naga_expressions.get_span(expr.handle);
+ self.naga_expressions.append(load, span)
+ } else {
+ expr.handle
+ }
+ }
+
+ /// Creates a zero value constant of type `ty`
+ ///
+ /// Returns `None` if the given `ty` is not a constructible type
+ fn create_zero_value_constant(
+ &mut self,
+ ty: Handle<crate::Type>,
+ ) -> Option<Handle<crate::Constant>> {
+ let inner = match self.module.types[ty].inner {
+ crate::TypeInner::Scalar { kind, width } => {
+ let value = match kind {
+ crate::ScalarKind::Sint => crate::ScalarValue::Sint(0),
+ crate::ScalarKind::Uint => crate::ScalarValue::Uint(0),
+ crate::ScalarKind::Float => crate::ScalarValue::Float(0.),
+ crate::ScalarKind::Bool => crate::ScalarValue::Bool(false),
+ };
+ crate::ConstantInner::Scalar { width, value }
+ }
+ crate::TypeInner::Vector { size, kind, width } => {
+ let scalar_ty = self.ensure_type_exists(crate::TypeInner::Scalar { width, kind });
+ let component = self.create_zero_value_constant(scalar_ty)?;
+ crate::ConstantInner::Composite {
+ ty,
+ components: (0..size as u8).map(|_| component).collect(),
+ }
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let vec_ty = self.ensure_type_exists(crate::TypeInner::Vector {
+ width,
+ kind: crate::ScalarKind::Float,
+ size: rows,
+ });
+ let component = self.create_zero_value_constant(vec_ty)?;
+ crate::ConstantInner::Composite {
+ ty,
+ components: (0..columns as u8).map(|_| component).collect(),
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ let size = self.module.constants[size].to_array_length()?;
+ let component = self.create_zero_value_constant(base)?;
+ crate::ConstantInner::Composite {
+ ty,
+ components: (0..size).map(|_| component).collect(),
+ }
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ let members = members.clone();
+ crate::ConstantInner::Composite {
+ ty,
+ components: members
+ .iter()
+ .map(|member| self.create_zero_value_constant(member.ty))
+ .collect::<Option<_>>()?,
+ }
+ }
+ _ => return None,
+ };
+
+ let constant = self.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ Span::UNDEFINED,
+ );
+ Some(constant)
+ }
+
+ fn format_typeinner(&self, inner: &crate::TypeInner) -> String {
+ inner.to_wgsl(&self.module.types, &self.module.constants)
+ }
+
+ fn format_type(&self, handle: Handle<crate::Type>) -> String {
+ let ty = &self.module.types[handle];
+ match ty.name {
+ Some(ref name) => name.clone(),
+ None => self.format_typeinner(&ty.inner),
+ }
+ }
+
+ fn format_type_resolution(&self, resolution: &TypeResolution) -> String {
+ match *resolution {
+ TypeResolution::Handle(handle) => self.format_type(handle),
+ TypeResolution::Value(ref inner) => self.format_typeinner(inner),
+ }
+ }
+
+ fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle<crate::Type> {
+ self.as_output().ensure_type_exists(inner)
+ }
+}
+
+struct ArgumentContext<'ctx, 'source> {
+ args: std::slice::Iter<'ctx, Handle<ast::Expression<'source>>>,
+ min_args: u32,
+ args_used: u32,
+ total_args: u32,
+ span: Span,
+}
+
+impl<'source> ArgumentContext<'_, 'source> {
+ pub fn finish(self) -> Result<(), Error<'source>> {
+ if self.args.len() == 0 {
+ Ok(())
+ } else {
+ Err(Error::WrongArgumentCount {
+ found: self.total_args,
+ expected: self.min_args..self.args_used + 1,
+ span: self.span,
+ })
+ }
+ }
+
+ pub fn next(&mut self) -> Result<Handle<ast::Expression<'source>>, Error<'source>> {
+ match self.args.next().copied() {
+ Some(arg) => {
+ self.args_used += 1;
+ Ok(arg)
+ }
+ None => Err(Error::WrongArgumentCount {
+ found: self.total_args,
+ expected: self.min_args..self.args_used + 1,
+ span: self.span,
+ }),
+ }
+ }
+}
+
+/// A Naga [`Expression`] handle, with WGSL type information.
+///
+/// Naga and WGSL types are very close, but Naga lacks WGSL's 'reference' types,
+/// which we need to know to apply the Load Rule. This struct carries a Naga
+/// `Handle<Expression>` along with enough information to determine its WGSL type.
+///
+/// [`Expression`]: crate::Expression
+#[derive(Debug, Copy, Clone)]
+struct TypedExpression {
+ /// The handle of the Naga expression.
+ handle: Handle<crate::Expression>,
+
+ /// True if this expression's WGSL type is a reference.
+ ///
+ /// When this is true, `handle` must be a pointer.
+ is_reference: bool,
+}
+
+impl TypedExpression {
+ const fn non_reference(handle: Handle<crate::Expression>) -> TypedExpression {
+ TypedExpression {
+ handle,
+ is_reference: false,
+ }
+ }
+}
+
+enum Composition {
+ Single(u32),
+ Multi(crate::VectorSize, [crate::SwizzleComponent; 4]),
+}
+
+impl Composition {
+ const fn letter_component(letter: char) -> Option<crate::SwizzleComponent> {
+ use crate::SwizzleComponent as Sc;
+ match letter {
+ 'x' | 'r' => Some(Sc::X),
+ 'y' | 'g' => Some(Sc::Y),
+ 'z' | 'b' => Some(Sc::Z),
+ 'w' | 'a' => Some(Sc::W),
+ _ => None,
+ }
+ }
+
+ fn extract_impl(name: &str, name_span: Span) -> Result<u32, Error> {
+ let ch = name.chars().next().ok_or(Error::BadAccessor(name_span))?;
+ match Self::letter_component(ch) {
+ Some(sc) => Ok(sc as u32),
+ None => Err(Error::BadAccessor(name_span)),
+ }
+ }
+
+ fn make(name: &str, name_span: Span) -> Result<Self, Error> {
+ if name.len() > 1 {
+ let mut components = [crate::SwizzleComponent::X; 4];
+ for (comp, ch) in components.iter_mut().zip(name.chars()) {
+ *comp = Self::letter_component(ch).ok_or(Error::BadAccessor(name_span))?;
+ }
+
+ let size = match name.len() {
+ 2 => crate::VectorSize::Bi,
+ 3 => crate::VectorSize::Tri,
+ 4 => crate::VectorSize::Quad,
+ _ => return Err(Error::BadAccessor(name_span)),
+ };
+ Ok(Composition::Multi(size, components))
+ } else {
+ Self::extract_impl(name, name_span).map(Composition::Single)
+ }
+ }
+}
+
+/// An `ast::GlobalDecl` for which we have built the Naga IR equivalent.
+enum LoweredGlobalDecl {
+ Function(Handle<crate::Function>),
+ Var(Handle<crate::GlobalVariable>),
+ Const(Handle<crate::Constant>),
+ Type(Handle<crate::Type>),
+ EntryPoint,
+}
+
+enum ConstantOrInner {
+ Constant(Handle<crate::Constant>),
+ Inner(crate::ConstantInner),
+}
+
+enum Texture {
+ Gather,
+ GatherCompare,
+
+ Sample,
+ SampleBias,
+ SampleCompare,
+ SampleCompareLevel,
+ SampleGrad,
+ SampleLevel,
+ // SampleBaseClampToEdge,
+}
+
+impl Texture {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "textureGather" => Self::Gather,
+ "textureGatherCompare" => Self::GatherCompare,
+
+ "textureSample" => Self::Sample,
+ "textureSampleBias" => Self::SampleBias,
+ "textureSampleCompare" => Self::SampleCompare,
+ "textureSampleCompareLevel" => Self::SampleCompareLevel,
+ "textureSampleGrad" => Self::SampleGrad,
+ "textureSampleLevel" => Self::SampleLevel,
+ // "textureSampleBaseClampToEdge" => Some(Self::SampleBaseClampToEdge),
+ _ => return None,
+ })
+ }
+
+ pub const fn min_argument_count(&self) -> u32 {
+ match *self {
+ Self::Gather => 3,
+ Self::GatherCompare => 4,
+
+ Self::Sample => 3,
+ Self::SampleBias => 5,
+ Self::SampleCompare => 5,
+ Self::SampleCompareLevel => 5,
+ Self::SampleGrad => 6,
+ Self::SampleLevel => 5,
+ // Self::SampleBaseClampToEdge => 3,
+ }
+ }
+}
+
+pub struct Lowerer<'source, 'temp> {
+ index: &'temp Index<'source>,
+ layouter: Layouter,
+}
+
+impl<'source, 'temp> Lowerer<'source, 'temp> {
+ pub fn new(index: &'temp Index<'source>) -> Self {
+ Self {
+ index,
+ layouter: Layouter::default(),
+ }
+ }
+
+ pub fn lower(
+ &mut self,
+ tu: &'temp ast::TranslationUnit<'source>,
+ ) -> Result<crate::Module, Error<'source>> {
+ let mut module = crate::Module::default();
+
+ let mut ctx = OutputContext {
+ ast_expressions: &tu.expressions,
+ globals: &mut FastHashMap::default(),
+ types: &tu.types,
+ module: &mut module,
+ };
+
+ for decl_handle in self.index.visit_ordered() {
+ let span = tu.decls.get_span(decl_handle);
+ let decl = &tu.decls[decl_handle];
+
+ match decl.kind {
+ ast::GlobalDeclKind::Fn(ref f) => {
+ let lowered_decl = self.function(f, span, ctx.reborrow())?;
+ ctx.globals.insert(f.name.name, lowered_decl);
+ }
+ ast::GlobalDeclKind::Var(ref v) => {
+ let ty = self.resolve_ast_type(v.ty, ctx.reborrow())?;
+
+ let init = v
+ .init
+ .map(|init| self.constant(init, ctx.reborrow()))
+ .transpose()?;
+
+ let handle = ctx.module.global_variables.append(
+ crate::GlobalVariable {
+ name: Some(v.name.name.to_string()),
+ space: v.space,
+ binding: v.binding.clone(),
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(v.name.name, LoweredGlobalDecl::Var(handle));
+ }
+ ast::GlobalDeclKind::Const(ref c) => {
+ let inner = self.constant_inner(c.init, ctx.reborrow())?;
+ let inner = match inner {
+ ConstantOrInner::Constant(c) => ctx.module.constants[c].inner.clone(),
+ ConstantOrInner::Inner(inner) => inner,
+ };
+
+ let inferred_type = match inner {
+ crate::ConstantInner::Scalar { width, value } => {
+ ctx.ensure_type_exists(crate::TypeInner::Scalar {
+ width,
+ kind: value.scalar_kind(),
+ })
+ }
+ crate::ConstantInner::Composite { ty, .. } => ty,
+ };
+
+ let handle = ctx.module.constants.append(
+ crate::Constant {
+ name: Some(c.name.name.to_string()),
+ specialization: None,
+ inner,
+ },
+ span,
+ );
+
+ let explicit_ty =
+ c.ty.map(|ty| self.resolve_ast_type(ty, ctx.reborrow()))
+ .transpose()?;
+
+ if let Some(explicit) = explicit_ty {
+ if explicit != inferred_type {
+ let ty = &ctx.module.types[explicit];
+ let explicit = ty.name.clone().unwrap_or_else(|| {
+ ty.inner.to_wgsl(&ctx.module.types, &ctx.module.constants)
+ });
+
+ let ty = &ctx.module.types[inferred_type];
+ let inferred = ty.name.clone().unwrap_or_else(|| {
+ ty.inner.to_wgsl(&ctx.module.types, &ctx.module.constants)
+ });
+
+ return Err(Error::InitializationTypeMismatch(
+ c.name.span,
+ explicit,
+ inferred,
+ ));
+ }
+ }
+
+ ctx.globals
+ .insert(c.name.name, LoweredGlobalDecl::Const(handle));
+ }
+ ast::GlobalDeclKind::Struct(ref s) => {
+ let handle = self.r#struct(s, span, ctx.reborrow())?;
+ ctx.globals
+ .insert(s.name.name, LoweredGlobalDecl::Type(handle));
+ }
+ ast::GlobalDeclKind::Type(ref alias) => {
+ let ty = self.resolve_ast_type(alias.ty, ctx.reborrow())?;
+ ctx.globals
+ .insert(alias.name.name, LoweredGlobalDecl::Type(ty));
+ }
+ }
+ }
+
+ Ok(module)
+ }
+
+ fn function(
+ &mut self,
+ f: &ast::Function<'source>,
+ span: Span,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<LoweredGlobalDecl, Error<'source>> {
+ let mut local_table = FastHashMap::default();
+ let mut local_variables = Arena::new();
+ let mut expressions = Arena::new();
+ let mut named_expressions = IndexMap::default();
+
+ let arguments = f
+ .arguments
+ .iter()
+ .enumerate()
+ .map(|(i, arg)| {
+ let ty = self.resolve_ast_type(arg.ty, ctx.reborrow())?;
+ let expr = expressions
+ .append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
+ local_table.insert(arg.handle, TypedExpression::non_reference(expr));
+ named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
+
+ Ok(crate::FunctionArgument {
+ name: Some(arg.name.name.to_string()),
+ ty,
+ binding: self.interpolate_default(&arg.binding, ty, ctx.reborrow()),
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let result = f
+ .result
+ .as_ref()
+ .map(|res| {
+ self.resolve_ast_type(res.ty, ctx.reborrow())
+ .map(|ty| crate::FunctionResult {
+ ty,
+ binding: self.interpolate_default(&res.binding, ty, ctx.reborrow()),
+ })
+ })
+ .transpose()?;
+
+ let mut typifier = Typifier::default();
+ let mut body = self.block(
+ &f.body,
+ StatementContext {
+ local_table: &mut local_table,
+ globals: ctx.globals,
+ ast_expressions: ctx.ast_expressions,
+ typifier: &mut typifier,
+ variables: &mut local_variables,
+ naga_expressions: &mut expressions,
+ named_expressions: &mut named_expressions,
+ types: ctx.types,
+ module: ctx.module,
+ arguments: &arguments,
+ },
+ )?;
+ ensure_block_returns(&mut body);
+
+ let function = crate::Function {
+ name: Some(f.name.name.to_string()),
+ arguments,
+ result,
+ local_variables,
+ expressions,
+ named_expressions: named_expressions
+ .into_iter()
+ .map(|(key, (name, _))| (key, name))
+ .collect(),
+ body,
+ };
+
+ if let Some(ref entry) = f.entry_point {
+ ctx.module.entry_points.push(crate::EntryPoint {
+ name: f.name.name.to_string(),
+ stage: entry.stage,
+ early_depth_test: entry.early_depth_test,
+ workgroup_size: entry.workgroup_size,
+ function,
+ });
+ Ok(LoweredGlobalDecl::EntryPoint)
+ } else {
+ let handle = ctx.module.functions.append(function, span);
+ Ok(LoweredGlobalDecl::Function(handle))
+ }
+ }
+
+ fn block(
+ &mut self,
+ b: &ast::Block<'source>,
+ mut ctx: StatementContext<'source, '_, '_>,
+ ) -> Result<crate::Block, Error<'source>> {
+ let mut block = crate::Block::default();
+
+ for stmt in b.stmts.iter() {
+ self.statement(stmt, &mut block, ctx.reborrow())?;
+ }
+
+ Ok(block)
+ }
+
+ fn statement(
+ &mut self,
+ stmt: &ast::Statement<'source>,
+ block: &mut crate::Block,
+ mut ctx: StatementContext<'source, '_, '_>,
+ ) -> Result<(), Error<'source>> {
+ let out = match stmt.kind {
+ ast::StatementKind::Block(ref block) => {
+ let block = self.block(block, ctx.reborrow())?;
+ crate::Statement::Block(block)
+ }
+ ast::StatementKind::LocalDecl(ref decl) => match *decl {
+ ast::LocalDecl::Let(ref l) => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let value = self.expression(l.init, ctx.as_expression(block, &mut emitter))?;
+
+ let explicit_ty =
+ l.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_output()))
+ .transpose()?;
+
+ if let Some(ty) = explicit_ty {
+ let mut ctx = ctx.as_expression(block, &mut emitter);
+ let init_ty = ctx.register_type(value)?;
+ if !ctx.module.types[ty]
+ .inner
+ .equivalent(&ctx.module.types[init_ty].inner, &ctx.module.types)
+ {
+ return Err(Error::InitializationTypeMismatch(
+ l.name.span,
+ ctx.format_type(ty),
+ ctx.format_type(init_ty),
+ ));
+ }
+ }
+
+ block.extend(emitter.finish(ctx.naga_expressions));
+ ctx.local_table
+ .insert(l.handle, TypedExpression::non_reference(value));
+ ctx.named_expressions
+ .insert(value, (l.name.name.to_string(), l.name.span));
+
+ return Ok(());
+ }
+ ast::LocalDecl::Var(ref v) => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let initializer = match v.init {
+ Some(init) => {
+ let initializer =
+ self.expression(init, ctx.as_expression(block, &mut emitter))?;
+ ctx.as_expression(block, &mut emitter)
+ .grow_types(initializer)?;
+ Some(initializer)
+ }
+ None => None,
+ };
+
+ let explicit_ty =
+ v.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_output()))
+ .transpose()?;
+
+ let ty = match (explicit_ty, initializer) {
+ (Some(explicit), Some(initializer)) => {
+ let ctx = ctx.as_expression(block, &mut emitter);
+ let initializer_ty = ctx.resolved_inner(initializer);
+ if !ctx.module.types[explicit]
+ .inner
+ .equivalent(initializer_ty, &ctx.module.types)
+ {
+ return Err(Error::InitializationTypeMismatch(
+ v.name.span,
+ ctx.format_type(explicit),
+ ctx.format_typeinner(initializer_ty),
+ ));
+ }
+ explicit
+ }
+ (Some(explicit), None) => explicit,
+ (None, Some(initializer)) => ctx
+ .as_expression(block, &mut emitter)
+ .register_type(initializer)?,
+ (None, None) => {
+ return Err(Error::MissingType(v.name.span));
+ }
+ };
+
+ let var = ctx.variables.append(
+ crate::LocalVariable {
+ name: Some(v.name.name.to_string()),
+ ty,
+ init: None,
+ },
+ stmt.span,
+ );
+
+ let handle = ctx
+ .as_expression(block, &mut emitter)
+ .interrupt_emitter(crate::Expression::LocalVariable(var), Span::UNDEFINED);
+ block.extend(emitter.finish(ctx.naga_expressions));
+ ctx.local_table.insert(
+ v.handle,
+ TypedExpression {
+ handle,
+ is_reference: true,
+ },
+ );
+
+ match initializer {
+ Some(initializer) => crate::Statement::Store {
+ pointer: handle,
+ value: initializer,
+ },
+ None => return Ok(()),
+ }
+ }
+ },
+ ast::StatementKind::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let condition =
+ self.expression(condition, ctx.as_expression(block, &mut emitter))?;
+ block.extend(emitter.finish(ctx.naga_expressions));
+
+ let accept = self.block(accept, ctx.reborrow())?;
+ let reject = self.block(reject, ctx.reborrow())?;
+
+ crate::Statement::If {
+ condition,
+ accept,
+ reject,
+ }
+ }
+ ast::StatementKind::Switch {
+ selector,
+ ref cases,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+ let selector = self.expression(selector, ectx.reborrow())?;
+
+ ectx.grow_types(selector)?;
+ let uint =
+ ectx.resolved_inner(selector).scalar_kind() == Some(crate::ScalarKind::Uint);
+ block.extend(emitter.finish(ctx.naga_expressions));
+
+ let cases = cases
+ .iter()
+ .map(|case| {
+ Ok(crate::SwitchCase {
+ value: match case.value {
+ ast::SwitchValue::I32(value) if !uint => {
+ crate::SwitchValue::I32(value)
+ }
+ ast::SwitchValue::U32(value) if uint => {
+ crate::SwitchValue::U32(value)
+ }
+ ast::SwitchValue::Default => crate::SwitchValue::Default,
+ _ => {
+ return Err(Error::InvalidSwitchValue {
+ uint,
+ span: case.value_span,
+ });
+ }
+ },
+ body: self.block(&case.body, ctx.reborrow())?,
+ fall_through: case.fall_through,
+ })
+ })
+ .collect::<Result<_, _>>()?;
+
+ crate::Statement::Switch { selector, cases }
+ }
+ ast::StatementKind::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let body = self.block(body, ctx.reborrow())?;
+ let mut continuing = self.block(continuing, ctx.reborrow())?;
+
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+ let break_if = break_if
+ .map(|expr| self.expression(expr, ctx.as_expression(block, &mut emitter)))
+ .transpose()?;
+ continuing.extend(emitter.finish(ctx.naga_expressions));
+
+ crate::Statement::Loop {
+ body,
+ continuing,
+ break_if,
+ }
+ }
+ ast::StatementKind::Break => crate::Statement::Break,
+ ast::StatementKind::Continue => crate::Statement::Continue,
+ ast::StatementKind::Return { value } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let value = value
+ .map(|expr| self.expression(expr, ctx.as_expression(block, &mut emitter)))
+ .transpose()?;
+ block.extend(emitter.finish(ctx.naga_expressions));
+
+ crate::Statement::Return { value }
+ }
+ ast::StatementKind::Kill => crate::Statement::Kill,
+ ast::StatementKind::Call {
+ ref function,
+ ref arguments,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let _ = self.call(
+ stmt.span,
+ function,
+ arguments,
+ ctx.as_expression(block, &mut emitter),
+ )?;
+ block.extend(emitter.finish(ctx.naga_expressions));
+ return Ok(());
+ }
+ ast::StatementKind::Assign { target, op, value } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let expr =
+ self.expression_for_reference(target, ctx.as_expression(block, &mut emitter))?;
+ let mut value = self.expression(value, ctx.as_expression(block, &mut emitter))?;
+
+ if !expr.is_reference {
+ let ty = ctx.invalid_assignment_type(expr.handle);
+
+ return Err(Error::InvalidAssignment {
+ span: ctx.ast_expressions.get_span(target),
+ ty,
+ });
+ }
+
+ let value = match op {
+ Some(op) => {
+ let mut ctx = ctx.as_expression(block, &mut emitter);
+ let mut left = ctx.apply_load_rule(expr);
+ ctx.binary_op_splat(op, &mut left, &mut value)?;
+ ctx.naga_expressions.append(
+ crate::Expression::Binary {
+ op,
+ left,
+ right: value,
+ },
+ stmt.span,
+ )
+ }
+ None => value,
+ };
+ block.extend(emitter.finish(ctx.naga_expressions));
+
+ crate::Statement::Store {
+ pointer: expr.handle,
+ value,
+ }
+ }
+ ast::StatementKind::Increment(value) | ast::StatementKind::Decrement(value) => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let op = match stmt.kind {
+ ast::StatementKind::Increment(_) => crate::BinaryOperator::Add,
+ ast::StatementKind::Decrement(_) => crate::BinaryOperator::Subtract,
+ _ => unreachable!(),
+ };
+
+ let value_span = ctx.ast_expressions.get_span(value);
+ let reference =
+ self.expression_for_reference(value, ctx.as_expression(block, &mut emitter))?;
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+
+ ectx.grow_types(reference.handle)?;
+ let (kind, width) = match *ectx.resolved_inner(reference.handle) {
+ crate::TypeInner::ValuePointer {
+ size: None,
+ kind,
+ width,
+ ..
+ } => (kind, width),
+ crate::TypeInner::Pointer { base, .. } => match ectx.module.types[base].inner {
+ crate::TypeInner::Scalar { kind, width } => (kind, width),
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ },
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ };
+ let constant_inner = crate::ConstantInner::Scalar {
+ width,
+ value: match kind {
+ crate::ScalarKind::Sint => crate::ScalarValue::Sint(1),
+ crate::ScalarKind::Uint => crate::ScalarValue::Uint(1),
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ },
+ };
+ let constant = ectx.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: constant_inner,
+ },
+ Span::UNDEFINED,
+ );
+
+ let left = ectx.naga_expressions.append(
+ crate::Expression::Load {
+ pointer: reference.handle,
+ },
+ value_span,
+ );
+ let right =
+ ectx.interrupt_emitter(crate::Expression::Constant(constant), Span::UNDEFINED);
+ let value = ectx
+ .naga_expressions
+ .append(crate::Expression::Binary { op, left, right }, stmt.span);
+
+ block.extend(emitter.finish(ctx.naga_expressions));
+ crate::Statement::Store {
+ pointer: reference.handle,
+ value,
+ }
+ }
+ ast::StatementKind::Ignore(expr) => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let _ = self.expression(expr, ctx.as_expression(block, &mut emitter))?;
+ block.extend(emitter.finish(ctx.naga_expressions));
+ return Ok(());
+ }
+ };
+
+ block.push(out, stmt.span);
+
+ Ok(())
+ }
+
+ fn expression(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let expr = self.expression_for_reference(expr, ctx.reborrow())?;
+ Ok(ctx.apply_load_rule(expr))
+ }
+
+ fn expression_for_reference(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<TypedExpression, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let expr = &ctx.ast_expressions[expr];
+
+ let (expr, is_reference) = match *expr {
+ ast::Expression::Literal(literal) => {
+ let inner = match literal {
+ ast::Literal::Number(Number::F32(f)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Float(f as _),
+ },
+ ast::Literal::Number(Number::I32(i)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Sint(i as _),
+ },
+ ast::Literal::Number(Number::U32(u)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(u as _),
+ },
+ ast::Literal::Number(_) => {
+ unreachable!("got abstract numeric type when not expected");
+ }
+ ast::Literal::Bool(b) => crate::ConstantInner::Scalar {
+ width: 1,
+ value: crate::ScalarValue::Bool(b),
+ },
+ };
+ let handle = ctx.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ Span::UNDEFINED,
+ );
+ let handle = ctx.interrupt_emitter(crate::Expression::Constant(handle), span);
+ return Ok(TypedExpression::non_reference(handle));
+ }
+ ast::Expression::Ident(ast::IdentExpr::Local(local)) => {
+ return Ok(ctx.local_table[&local])
+ }
+ ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => {
+ return if let Some(global) = ctx.globals.get(name) {
+ let (expr, is_reference) = match *global {
+ LoweredGlobalDecl::Var(handle) => (
+ crate::Expression::GlobalVariable(handle),
+ ctx.module.global_variables[handle].space
+ != crate::AddressSpace::Handle,
+ ),
+ LoweredGlobalDecl::Const(handle) => {
+ (crate::Expression::Constant(handle), false)
+ }
+ _ => {
+ return Err(Error::Unexpected(span, ExpectedToken::Variable));
+ }
+ };
+
+ let handle = ctx.interrupt_emitter(expr, span);
+ Ok(TypedExpression {
+ handle,
+ is_reference,
+ })
+ } else {
+ Err(Error::UnknownIdent(span, name))
+ }
+ }
+ ast::Expression::Construct {
+ ref ty,
+ ty_span,
+ ref components,
+ } => {
+ let handle = self.construct(span, ty, ty_span, components, ctx.reborrow())?;
+ return Ok(TypedExpression::non_reference(handle));
+ }
+ ast::Expression::Unary { op, expr } => {
+ let expr = self.expression(expr, ctx.reborrow())?;
+ (crate::Expression::Unary { op, expr }, false)
+ }
+ ast::Expression::AddrOf(expr) => {
+ // The `&` operator simply converts a reference to a pointer. And since a
+ // reference is required, the Load Rule is not applied.
+ let expr = self.expression_for_reference(expr, ctx.reborrow())?;
+ if !expr.is_reference {
+ return Err(Error::NotReference("the operand of the `&` operator", span));
+ }
+
+ // No code is generated. We just declare the pointer a reference now.
+ return Ok(TypedExpression {
+ is_reference: false,
+ ..expr
+ });
+ }
+ ast::Expression::Deref(expr) => {
+ // The pointer we dereference must be loaded.
+ let pointer = self.expression(expr, ctx.reborrow())?;
+
+ ctx.grow_types(pointer)?;
+ if ctx.resolved_inner(pointer).pointer_space().is_none() {
+ return Err(Error::NotPointer(span));
+ }
+
+ return Ok(TypedExpression {
+ handle: pointer,
+ is_reference: true,
+ });
+ }
+ ast::Expression::Binary { op, left, right } => {
+ // Load both operands.
+ let mut left = self.expression(left, ctx.reborrow())?;
+ let mut right = self.expression(right, ctx.reborrow())?;
+ ctx.binary_op_splat(op, &mut left, &mut right)?;
+ (crate::Expression::Binary { op, left, right }, false)
+ }
+ ast::Expression::Call {
+ ref function,
+ ref arguments,
+ } => {
+ let handle = self
+ .call(span, function, arguments, ctx.reborrow())?
+ .ok_or(Error::FunctionReturnsVoid(function.span))?;
+ return Ok(TypedExpression::non_reference(handle));
+ }
+ ast::Expression::Index { base, index } => {
+ let expr = self.expression_for_reference(base, ctx.reborrow())?;
+ let index = self.expression(index, ctx.reborrow())?;
+
+ ctx.grow_types(expr.handle)?;
+ let wgsl_pointer =
+ ctx.resolved_inner(expr.handle).pointer_space().is_some() && !expr.is_reference;
+
+ if wgsl_pointer {
+ return Err(Error::Pointer(
+ "the value indexed by a `[]` subscripting expression",
+ ctx.ast_expressions.get_span(base),
+ ));
+ }
+
+ if let crate::Expression::Constant(constant) = ctx.naga_expressions[index] {
+ let span = ctx.naga_expressions.get_span(index);
+ let index = match ctx.module.constants[constant].inner {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(int),
+ ..
+ } => u32::try_from(int).map_err(|_| Error::BadU32Constant(span)),
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Sint(int),
+ ..
+ } => u32::try_from(int).map_err(|_| Error::BadU32Constant(span)),
+ _ => Err(Error::BadU32Constant(span)),
+ }?;
+
+ (
+ crate::Expression::AccessIndex {
+ base: expr.handle,
+ index,
+ },
+ expr.is_reference,
+ )
+ } else {
+ (
+ crate::Expression::Access {
+ base: expr.handle,
+ index,
+ },
+ expr.is_reference,
+ )
+ }
+ }
+ ast::Expression::Member { base, ref field } => {
+ let TypedExpression {
+ handle,
+ is_reference,
+ } = self.expression_for_reference(base, ctx.reborrow())?;
+
+ ctx.grow_types(handle)?;
+ let temp_inner;
+ let (composite, wgsl_pointer) = match *ctx.resolved_inner(handle) {
+ crate::TypeInner::Pointer { base, .. } => {
+ (&ctx.module.types[base].inner, !is_reference)
+ }
+ crate::TypeInner::ValuePointer {
+ size: None,
+ kind,
+ width,
+ ..
+ } => {
+ temp_inner = crate::TypeInner::Scalar { kind, width };
+ (&temp_inner, !is_reference)
+ }
+ crate::TypeInner::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ ..
+ } => {
+ temp_inner = crate::TypeInner::Vector { size, kind, width };
+ (&temp_inner, !is_reference)
+ }
+ ref other => (other, false),
+ };
+
+ if wgsl_pointer {
+ return Err(Error::Pointer(
+ "the value accessed by a `.member` expression",
+ ctx.ast_expressions.get_span(base),
+ ));
+ }
+
+ let access = match *composite {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let index = members
+ .iter()
+ .position(|m| m.name.as_deref() == Some(field.name))
+ .ok_or(Error::BadAccessor(field.span))?
+ as u32;
+
+ (
+ crate::Expression::AccessIndex {
+ base: handle,
+ index,
+ },
+ is_reference,
+ )
+ }
+ crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => {
+ match Composition::make(field.name, field.span)? {
+ Composition::Multi(size, pattern) => {
+ let vector = ctx.apply_load_rule(TypedExpression {
+ handle,
+ is_reference,
+ });
+
+ (
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ },
+ false,
+ )
+ }
+ Composition::Single(index) => (
+ crate::Expression::AccessIndex {
+ base: handle,
+ index,
+ },
+ is_reference,
+ ),
+ }
+ }
+ _ => return Err(Error::BadAccessor(field.span)),
+ };
+
+ access
+ }
+ ast::Expression::Bitcast { expr, to, ty_span } => {
+ let expr = self.expression(expr, ctx.reborrow())?;
+ let to_resolved = self.resolve_ast_type(to, ctx.as_output())?;
+
+ let kind = match ctx.module.types[to_resolved].inner {
+ crate::TypeInner::Scalar { kind, .. } => kind,
+ crate::TypeInner::Vector { kind, .. } => kind,
+ _ => {
+ let ty = &ctx.typifier[expr];
+ return Err(Error::BadTypeCast {
+ from_type: ctx.format_type_resolution(ty),
+ span: ty_span,
+ to_type: ctx.format_type(to_resolved),
+ });
+ }
+ };
+
+ (
+ crate::Expression::As {
+ expr,
+ kind,
+ convert: None,
+ },
+ false,
+ )
+ }
+ };
+
+ let handle = ctx.naga_expressions.append(expr, span);
+ Ok(TypedExpression {
+ handle,
+ is_reference,
+ })
+ }
+
+ /// Generate Naga IR for call expressions and statements, and type
+ /// constructor expressions.
+ ///
+ /// The "function" being called is simply an `Ident` that we know refers to
+ /// some module-scope definition.
+ ///
+ /// - If it is the name of a type, then the expression is a type constructor
+ /// expression: either constructing a value from components, a conversion
+ /// expression, or a zero value expression.
+ ///
+ /// - If it is the name of a function, then we're generating a [`Call`]
+ /// statement. We may be in the midst of generating code for an
+ /// expression, in which case we must generate an `Emit` statement to
+ /// force evaluation of the IR expressions we've generated so far, add the
+ /// `Call` statement to the current block, and then resume generating
+ /// expressions.
+ ///
+ /// [`Call`]: crate::Statement::Call
+ fn call(
+ &mut self,
+ span: Span,
+ function: &ast::Ident<'source>,
+ arguments: &[Handle<ast::Expression<'source>>],
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Option<Handle<crate::Expression>>, Error<'source>> {
+ match ctx.globals.get(function.name) {
+ Some(&LoweredGlobalDecl::Type(ty)) => {
+ let handle = self.construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ function.span,
+ arguments,
+ ctx.reborrow(),
+ )?;
+ Ok(Some(handle))
+ }
+ Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
+ Err(Error::Unexpected(function.span, ExpectedToken::Function))
+ }
+ Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
+ Some(&LoweredGlobalDecl::Function(function)) => {
+ let arguments = arguments
+ .iter()
+ .map(|&arg| self.expression(arg, ctx.reborrow()))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ let result = ctx.module.functions[function].result.is_some().then(|| {
+ ctx.naga_expressions
+ .append(crate::Expression::CallResult(function), span)
+ });
+ ctx.emitter.start(ctx.naga_expressions);
+ ctx.block.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ span,
+ );
+
+ Ok(result)
+ }
+ None => {
+ let span = function.span;
+ let expr = if let Some(fun) = conv::map_relational_fun(function.name) {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let argument = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::Relational { fun, argument }
+ } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::Derivative { axis, ctrl, expr }
+ } else if let Some(fun) = conv::map_standard_fun(function.name) {
+ let expected = fun.argument_count() as _;
+ let mut args = ctx.prepare_args(arguments, expected, span);
+
+ let arg = self.expression(args.next()?, ctx.reborrow())?;
+ let arg1 = args
+ .next()
+ .map(|x| self.expression(x, ctx.reborrow()))
+ .ok()
+ .transpose()?;
+ let arg2 = args
+ .next()
+ .map(|x| self.expression(x, ctx.reborrow()))
+ .ok()
+ .transpose()?;
+ let arg3 = args
+ .next()
+ .map(|x| self.expression(x, ctx.reborrow()))
+ .ok()
+ .transpose()?;
+
+ args.finish()?;
+
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ }
+ } else if let Some(fun) = Texture::map(function.name) {
+ self.texture_sample_helper(fun, arguments, span, ctx.reborrow())?
+ } else {
+ match function.name {
+ "select" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let reject = self.expression(args.next()?, ctx.reborrow())?;
+ let accept = self.expression(args.next()?, ctx.reborrow())?;
+ let condition = self.expression(args.next()?, ctx.reborrow())?;
+
+ args.finish()?;
+
+ crate::Expression::Select {
+ reject,
+ accept,
+ condition,
+ }
+ }
+ "arrayLength" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::ArrayLength(expr)
+ }
+ "atomicLoad" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::Load { pointer }
+ }
+ "atomicStore" => {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+ let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?;
+ let value = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ ctx.emitter.start(ctx.naga_expressions);
+ ctx.block
+ .push(crate::Statement::Store { pointer, value }, span);
+ return Ok(None);
+ }
+ "atomicAdd" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Add,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicSub" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Subtract,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicAnd" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::And,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicOr" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::InclusiveOr,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicXor" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::ExclusiveOr,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicMin" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Min,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicMax" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Max,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicExchange" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Exchange { compare: None },
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicCompareExchangeWeak" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?;
+
+ let compare = self.expression(args.next()?, ctx.reborrow())?;
+
+ let value = args.next()?;
+ let value_span = ctx.ast_expressions.get_span(value);
+ let value = self.expression(value, ctx.reborrow())?;
+ ctx.grow_types(value)?;
+
+ args.finish()?;
+
+ let expression = match *ctx.resolved_inner(value) {
+ crate::TypeInner::Scalar { kind, width } => {
+ crate::Expression::AtomicResult {
+ //TODO: cache this to avoid generating duplicate types
+ ty: ctx
+ .module
+ .generate_atomic_compare_exchange_result(kind, width),
+ comparison: true,
+ }
+ }
+ _ => return Err(Error::InvalidAtomicOperandType(value_span)),
+ };
+
+ let result = ctx.interrupt_emitter(expression, span);
+ ctx.block.push(
+ crate::Statement::Atomic {
+ pointer,
+ fun: crate::AtomicFunction::Exchange {
+ compare: Some(compare),
+ },
+ value,
+ result,
+ },
+ span,
+ );
+ return Ok(Some(result));
+ }
+ "storageBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ ctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::STORAGE), span);
+ return Ok(None);
+ }
+ "workgroupBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ ctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
+ return Ok(None);
+ }
+ "textureStore" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx.reborrow())?;
+
+ let coordinate = self.expression(args.next()?, ctx.reborrow())?;
+
+ let (_, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ let value = self.expression(args.next()?, ctx.reborrow())?;
+
+ args.finish()?;
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ ctx.emitter.start(ctx.naga_expressions);
+ let stmt = crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ };
+ ctx.block.push(stmt, span);
+ return Ok(None);
+ }
+ "textureLoad" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx.reborrow())?;
+
+ let coordinate = self.expression(args.next()?, ctx.reborrow())?;
+
+ let (class, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ let level = class
+ .is_mipmapped()
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ let sample = class
+ .is_multisampled()
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ args.finish()?;
+
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ level,
+ sample,
+ }
+ }
+ "textureDimensions" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx.reborrow())?;
+ let level = args
+ .next()
+ .map(|arg| self.expression(arg, ctx.reborrow()))
+ .ok()
+ .transpose()?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::Size { level },
+ }
+ }
+ "textureNumLevels" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumLevels,
+ }
+ }
+ "textureNumLayers" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumLayers,
+ }
+ }
+ "textureNumSamples" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumSamples,
+ }
+ }
+ "rayQueryInitialize" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+ let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?;
+ let acceleration_structure =
+ self.expression(args.next()?, ctx.reborrow())?;
+ let descriptor = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ let _ = ctx.module.generate_ray_desc_type();
+ let fun = crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ };
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ ctx.emitter.start(ctx.naga_expressions);
+ ctx.block
+ .push(crate::Statement::RayQuery { query, fun }, span);
+ return Ok(None);
+ }
+ "rayQueryProceed" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ let result = ctx
+ .naga_expressions
+ .append(crate::Expression::RayQueryProceedResult, span);
+ let fun = crate::RayQueryFunction::Proceed { result };
+
+ ctx.emitter.start(ctx.naga_expressions);
+ ctx.block
+ .push(crate::Statement::RayQuery { query, fun }, span);
+ return Ok(Some(result));
+ }
+ "rayQueryGetCommittedIntersection" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ let _ = ctx.module.generate_ray_intersection_type();
+
+ crate::Expression::RayQueryGetIntersection {
+ query,
+ committed: true,
+ }
+ }
+ "RayDesc" => {
+ let ty = ctx.module.generate_ray_desc_type();
+ let handle = self.construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ function.span,
+ arguments,
+ ctx.reborrow(),
+ )?;
+ return Ok(Some(handle));
+ }
+ _ => return Err(Error::UnknownIdent(function.span, function.name)),
+ }
+ };
+
+ let expr = ctx.naga_expressions.append(expr, span);
+ Ok(Some(expr))
+ }
+ }
+ }
+
+ fn atomic_pointer(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let pointer = self.expression(expr, ctx.reborrow())?;
+
+ ctx.grow_types(pointer)?;
+ match *ctx.resolved_inner(pointer) {
+ crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Ok(pointer),
+ ref other => {
+ log::error!("Pointer type to {:?} passed to atomic op", other);
+ Err(Error::InvalidAtomicPointer(span))
+ }
+ },
+ ref other => {
+ log::error!("Type {:?} passed to atomic op", other);
+ Err(Error::InvalidAtomicPointer(span))
+ }
+ }
+ }
+
+ fn atomic_helper(
+ &mut self,
+ span: Span,
+ fun: crate::AtomicFunction,
+ args: &[Handle<ast::Expression<'source>>],
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(args, 2, span);
+
+ let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?;
+
+ let value = args.next()?;
+ let value = self.expression(value, ctx.reborrow())?;
+ let ty = ctx.register_type(value)?;
+
+ args.finish()?;
+
+ let result = ctx.interrupt_emitter(
+ crate::Expression::AtomicResult {
+ ty,
+ comparison: false,
+ },
+ span,
+ );
+ ctx.block.push(
+ crate::Statement::Atomic {
+ pointer,
+ fun,
+ value,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
+ fn texture_sample_helper(
+ &mut self,
+ fun: Texture,
+ args: &[Handle<ast::Expression<'source>>],
+ span: Span,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<crate::Expression, Error<'source>> {
+ let mut args = ctx.prepare_args(args, fun.min_argument_count(), span);
+
+ let (image, gather) = match fun {
+ Texture::Gather => {
+ let image_or_component = args.next()?;
+ match self.gather_component(image_or_component, ctx.reborrow())? {
+ Some(component) => {
+ let image = args.next()?;
+ (image, Some(component))
+ }
+ None => (image_or_component, Some(crate::SwizzleComponent::X)),
+ }
+ }
+ Texture::GatherCompare => {
+ let image = args.next()?;
+ (image, Some(crate::SwizzleComponent::X))
+ }
+
+ _ => {
+ let image = args.next()?;
+ (image, None)
+ }
+ };
+
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx.reborrow())?;
+
+ let sampler = self.expression(args.next()?, ctx.reborrow())?;
+
+ let coordinate = self.expression(args.next()?, ctx.reborrow())?;
+
+ let (_, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ let (level, depth_ref) = match fun {
+ Texture::Gather => (crate::SampleLevel::Zero, None),
+ Texture::GatherCompare => {
+ let reference = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Zero, Some(reference))
+ }
+
+ Texture::Sample => (crate::SampleLevel::Auto, None),
+ Texture::SampleBias => {
+ let bias = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Bias(bias), None)
+ }
+ Texture::SampleCompare => {
+ let reference = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Auto, Some(reference))
+ }
+ Texture::SampleCompareLevel => {
+ let reference = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Zero, Some(reference))
+ }
+ Texture::SampleGrad => {
+ let x = self.expression(args.next()?, ctx.reborrow())?;
+ let y = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Gradient { x, y }, None)
+ }
+ Texture::SampleLevel => {
+ let level = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Exact(level), None)
+ }
+ };
+
+ let offset = args
+ .next()
+ .map(|arg| self.constant(arg, ctx.as_output()))
+ .ok()
+ .transpose()?;
+
+ args.finish()?;
+
+ Ok(crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ })
+ }
+
+ fn gather_component(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Option<crate::SwizzleComponent>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+
+ let constant = match self.constant_inner(expr, ctx.as_output()).ok() {
+ Some(ConstantOrInner::Constant(c)) => ctx.module.constants[c].inner.clone(),
+ Some(ConstantOrInner::Inner(inner)) => inner,
+ None => return Ok(None),
+ };
+
+ let int = match constant {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Sint(i),
+ ..
+ } if i >= 0 => i as u64,
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(i),
+ ..
+ } => i,
+ _ => {
+ return Err(Error::InvalidGatherComponent(span));
+ }
+ };
+
+ crate::SwizzleComponent::XYZW
+ .get(int as usize)
+ .copied()
+ .map(Some)
+ .ok_or(Error::InvalidGatherComponent(span))
+ }
+
+ fn r#struct(
+ &mut self,
+ s: &ast::Struct<'source>,
+ span: Span,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ let mut offset = 0;
+ let mut struct_alignment = Alignment::ONE;
+ let mut members = Vec::with_capacity(s.members.len());
+
+ for member in s.members.iter() {
+ let ty = self.resolve_ast_type(member.ty, ctx.reborrow())?;
+
+ self.layouter
+ .update(&ctx.module.types, &ctx.module.constants)
+ .unwrap();
+
+ let member_min_size = self.layouter[ty].size;
+ let member_min_alignment = self.layouter[ty].alignment;
+
+ let member_size = if let Some((size, span)) = member.size {
+ if size < member_min_size {
+ return Err(Error::SizeAttributeTooLow(span, member_min_size));
+ } else {
+ size
+ }
+ } else {
+ member_min_size
+ };
+
+ let member_alignment = if let Some((align, span)) = member.align {
+ if let Some(alignment) = Alignment::new(align) {
+ if alignment < member_min_alignment {
+ return Err(Error::AlignAttributeTooLow(span, member_min_alignment));
+ } else {
+ alignment
+ }
+ } else {
+ return Err(Error::NonPowerOfTwoAlignAttribute(span));
+ }
+ } else {
+ member_min_alignment
+ };
+
+ let binding = self.interpolate_default(&member.binding, ty, ctx.reborrow());
+
+ offset = member_alignment.round_up(offset);
+ struct_alignment = struct_alignment.max(member_alignment);
+
+ members.push(crate::StructMember {
+ name: Some(member.name.name.to_owned()),
+ ty,
+ binding,
+ offset,
+ });
+
+ offset += member_size;
+ }
+
+ let size = struct_alignment.round_up(offset);
+ let inner = crate::TypeInner::Struct {
+ members,
+ span: size,
+ };
+
+ let handle = ctx.module.types.insert(
+ crate::Type {
+ name: Some(s.name.name.to_string()),
+ inner,
+ },
+ span,
+ );
+ Ok(handle)
+ }
+
+ /// Return a Naga `Handle<Type>` representing the front-end type `handle`.
+ fn resolve_ast_type(
+ &mut self,
+ handle: Handle<ast::Type<'source>>,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ let inner = match ctx.types[handle] {
+ ast::Type::Scalar { kind, width } => crate::TypeInner::Scalar { kind, width },
+ ast::Type::Vector { size, kind, width } => {
+ crate::TypeInner::Vector { size, kind, width }
+ }
+ ast::Type::Matrix {
+ rows,
+ columns,
+ width,
+ } => crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ ast::Type::Atomic { kind, width } => crate::TypeInner::Atomic { kind, width },
+ ast::Type::Pointer { base, space } => {
+ let base = self.resolve_ast_type(base, ctx.reborrow())?;
+ crate::TypeInner::Pointer { base, space }
+ }
+ ast::Type::Array { base, size } => {
+ let base = self.resolve_ast_type(base, ctx.reborrow())?;
+ self.layouter
+ .update(&ctx.module.types, &ctx.module.constants)
+ .unwrap();
+
+ crate::TypeInner::Array {
+ base,
+ size: match size {
+ ast::ArraySize::Constant(constant) => {
+ let constant = self.constant(constant, ctx.reborrow())?;
+ crate::ArraySize::Constant(constant)
+ }
+ ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
+ },
+ stride: self.layouter[base].to_stride(),
+ }
+ }
+ ast::Type::Image {
+ dim,
+ arrayed,
+ class,
+ } => crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison },
+ ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure,
+ ast::Type::RayQuery => crate::TypeInner::RayQuery,
+ ast::Type::BindingArray { base, size } => {
+ let base = self.resolve_ast_type(base, ctx.reborrow())?;
+
+ crate::TypeInner::BindingArray {
+ base,
+ size: match size {
+ ast::ArraySize::Constant(constant) => {
+ let constant = self.constant(constant, ctx.reborrow())?;
+ crate::ArraySize::Constant(constant)
+ }
+ ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
+ },
+ }
+ }
+ ast::Type::RayDesc => {
+ return Ok(ctx.module.generate_ray_desc_type());
+ }
+ ast::Type::RayIntersection => {
+ return Ok(ctx.module.generate_ray_intersection_type());
+ }
+ ast::Type::User(ref ident) => {
+ return match ctx.globals.get(ident.name) {
+ Some(&LoweredGlobalDecl::Type(handle)) => Ok(handle),
+ Some(_) => Err(Error::Unexpected(ident.span, ExpectedToken::Type)),
+ None => Err(Error::UnknownType(ident.span)),
+ }
+ }
+ };
+
+ Ok(ctx.ensure_type_exists(inner))
+ }
+
+ /// Find or construct a Naga [`Constant`] whose value is `expr`.
+ ///
+ /// The `ctx` indicates the Naga [`Module`] to which we should add
+ /// new `Constant`s or [`Type`]s as needed.
+ ///
+ /// [`Module`]: crate::Module
+ /// [`Constant`]: crate::Constant
+ /// [`Type`]: crate::Type
+ fn constant(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Constant>, Error<'source>> {
+ let inner = match self.constant_inner(expr, ctx.reborrow())? {
+ ConstantOrInner::Constant(c) => return Ok(c),
+ ConstantOrInner::Inner(inner) => inner,
+ };
+
+ let c = ctx.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ Span::UNDEFINED,
+ );
+ Ok(c)
+ }
+
+ fn constant_inner(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<ConstantOrInner, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let inner = match ctx.ast_expressions[expr] {
+ ast::Expression::Literal(literal) => match literal {
+ ast::Literal::Number(Number::F32(f)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Float(f as _),
+ },
+ ast::Literal::Number(Number::I32(i)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Sint(i as _),
+ },
+ ast::Literal::Number(Number::U32(u)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(u as _),
+ },
+ ast::Literal::Number(_) => {
+ unreachable!("got abstract numeric type when not expected");
+ }
+ ast::Literal::Bool(b) => crate::ConstantInner::Scalar {
+ width: 1,
+ value: crate::ScalarValue::Bool(b),
+ },
+ },
+ ast::Expression::Ident(ast::IdentExpr::Local(_)) => {
+ return Err(Error::Unexpected(span, ExpectedToken::Constant))
+ }
+ ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => {
+ return if let Some(global) = ctx.globals.get(name) {
+ match *global {
+ LoweredGlobalDecl::Const(handle) => Ok(ConstantOrInner::Constant(handle)),
+ _ => Err(Error::Unexpected(span, ExpectedToken::Constant)),
+ }
+ } else {
+ Err(Error::UnknownIdent(span, name))
+ }
+ }
+ ast::Expression::Construct {
+ ref ty,
+ ref components,
+ ..
+ } => self.const_construct(span, ty, components, ctx.reborrow())?,
+ ast::Expression::Call {
+ ref function,
+ ref arguments,
+ } => match ctx.globals.get(function.name) {
+ Some(&LoweredGlobalDecl::Type(ty)) => self.const_construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ arguments,
+ ctx.reborrow(),
+ )?,
+ Some(_) => return Err(Error::ConstExprUnsupported(span)),
+ None => return Err(Error::UnknownIdent(function.span, function.name)),
+ },
+ _ => return Err(Error::ConstExprUnsupported(span)),
+ };
+
+ Ok(ConstantOrInner::Inner(inner))
+ }
+
+ fn interpolate_default(
+ &mut self,
+ binding: &Option<crate::Binding>,
+ ty: Handle<crate::Type>,
+ ctx: OutputContext<'source, '_, '_>,
+ ) -> Option<crate::Binding> {
+ let mut binding = binding.clone();
+ if let Some(ref mut binding) = binding {
+ binding.apply_default_interpolation(&ctx.module.types[ty].inner);
+ }
+
+ binding
+ }
+
+ fn ray_query_pointer(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let pointer = self.expression(expr, ctx.reborrow())?;
+
+ ctx.grow_types(pointer)?;
+ match *ctx.resolved_inner(pointer) {
+ crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
+ crate::TypeInner::RayQuery => Ok(pointer),
+ ref other => {
+ log::error!("Pointer type to {:?} passed to ray query op", other);
+ Err(Error::InvalidRayQueryPointer(span))
+ }
+ },
+ ref other => {
+ log::error!("Type {:?} passed to ray query op", other);
+ Err(Error::InvalidRayQueryPointer(span))
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs
new file mode 100644
index 0000000000..eb21fae6c9
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/mod.rs
@@ -0,0 +1,340 @@
+/*!
+Frontend for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+mod error;
+mod index;
+mod lower;
+mod parse;
+#[cfg(test)]
+mod tests;
+
+use crate::arena::{Arena, UniqueArena};
+
+use crate::front::wgsl::error::Error;
+use crate::front::wgsl::parse::Parser;
+use thiserror::Error;
+
+pub use crate::front::wgsl::error::ParseError;
+use crate::front::wgsl::lower::Lowerer;
+
+pub struct Frontend {
+ parser: Parser,
+}
+
+impl Frontend {
+ pub const fn new() -> Self {
+ Self {
+ parser: Parser::new(),
+ }
+ }
+
+ pub fn parse(&mut self, source: &str) -> Result<crate::Module, ParseError> {
+ self.inner(source).map_err(|x| x.as_parse_error(source))
+ }
+
+ fn inner<'a>(&mut self, source: &'a str) -> Result<crate::Module, Error<'a>> {
+ let tu = self.parser.parse(source)?;
+ let index = index::Index::generate(&tu)?;
+ let module = Lowerer::new(&index).lower(&tu)?;
+
+ Ok(module)
+ }
+}
+
+pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
+ Frontend::new().parse(source)
+}
+
+impl crate::StorageFormat {
+ const fn to_wgsl(self) -> &'static str {
+ use crate::StorageFormat as Sf;
+ match self {
+ Sf::R8Unorm => "r8unorm",
+ Sf::R8Snorm => "r8snorm",
+ Sf::R8Uint => "r8uint",
+ Sf::R8Sint => "r8sint",
+ Sf::R16Uint => "r16uint",
+ Sf::R16Sint => "r16sint",
+ Sf::R16Float => "r16float",
+ Sf::Rg8Unorm => "rg8unorm",
+ Sf::Rg8Snorm => "rg8snorm",
+ Sf::Rg8Uint => "rg8uint",
+ Sf::Rg8Sint => "rg8sint",
+ Sf::R32Uint => "r32uint",
+ Sf::R32Sint => "r32sint",
+ Sf::R32Float => "r32float",
+ Sf::Rg16Uint => "rg16uint",
+ Sf::Rg16Sint => "rg16sint",
+ Sf::Rg16Float => "rg16float",
+ Sf::Rgba8Unorm => "rgba8unorm",
+ Sf::Rgba8Snorm => "rgba8snorm",
+ Sf::Rgba8Uint => "rgba8uint",
+ Sf::Rgba8Sint => "rgba8sint",
+ Sf::Rgb10a2Unorm => "rgb10a2unorm",
+ Sf::Rg11b10Float => "rg11b10float",
+ Sf::Rg32Uint => "rg32uint",
+ Sf::Rg32Sint => "rg32sint",
+ Sf::Rg32Float => "rg32float",
+ Sf::Rgba16Uint => "rgba16uint",
+ Sf::Rgba16Sint => "rgba16sint",
+ Sf::Rgba16Float => "rgba16float",
+ Sf::Rgba32Uint => "rgba32uint",
+ Sf::Rgba32Sint => "rgba32sint",
+ Sf::Rgba32Float => "rgba32float",
+ Sf::R16Unorm => "r16unorm",
+ Sf::R16Snorm => "r16snorm",
+ Sf::Rg16Unorm => "rg16unorm",
+ Sf::Rg16Snorm => "rg16snorm",
+ Sf::Rgba16Unorm => "rgba16unorm",
+ Sf::Rgba16Snorm => "rgba16snorm",
+ }
+ }
+}
+
+impl crate::TypeInner {
+ /// Formats the type as it is written in wgsl.
+ ///
+ /// For example `vec3<f32>`.
+ ///
+ /// Note: The names of a `TypeInner::Struct` is not known. Therefore this method will simply return "struct" for them.
+ fn to_wgsl(
+ &self,
+ types: &UniqueArena<crate::Type>,
+ constants: &Arena<crate::Constant>,
+ ) -> String {
+ use crate::TypeInner as Ti;
+
+ match *self {
+ Ti::Scalar { kind, width } => kind.to_wgsl(width),
+ Ti::Vector { size, kind, width } => {
+ format!("vec{}<{}>", size as u32, kind.to_wgsl(width))
+ }
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ format!(
+ "mat{}x{}<{}>",
+ columns as u32,
+ rows as u32,
+ crate::ScalarKind::Float.to_wgsl(width),
+ )
+ }
+ Ti::Atomic { kind, width } => {
+ format!("atomic<{}>", kind.to_wgsl(width))
+ }
+ Ti::Pointer { base, .. } => {
+ let base = &types[base];
+ let name = base.name.as_deref().unwrap_or("unknown");
+ format!("ptr<{name}>")
+ }
+ Ti::ValuePointer { kind, width, .. } => {
+ format!("ptr<{}>", kind.to_wgsl(width))
+ }
+ Ti::Array { base, size, .. } => {
+ let member_type = &types[base];
+ let base = member_type.name.as_deref().unwrap_or("unknown");
+ match size {
+ crate::ArraySize::Constant(size) => {
+ let constant = &constants[size];
+ let size = constant
+ .name
+ .clone()
+ .unwrap_or_else(|| match constant.inner {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(size),
+ ..
+ } => size.to_string(),
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Sint(size),
+ ..
+ } => size.to_string(),
+ _ => "?".to_string(),
+ });
+ format!("array<{base}, {size}>")
+ }
+ crate::ArraySize::Dynamic => format!("array<{base}>"),
+ }
+ }
+ Ti::Struct { .. } => {
+ // TODO: Actually output the struct?
+ "struct".to_string()
+ }
+ Ti::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_suffix = match dim {
+ crate::ImageDimension::D1 => "_1d",
+ crate::ImageDimension::D2 => "_2d",
+ crate::ImageDimension::D3 => "_3d",
+ crate::ImageDimension::Cube => "_cube",
+ };
+ let array_suffix = if arrayed { "_array" } else { "" };
+
+ let class_suffix = match class {
+ crate::ImageClass::Sampled { multi: true, .. } => "_multisampled",
+ crate::ImageClass::Depth { multi: false } => "_depth",
+ crate::ImageClass::Depth { multi: true } => "_depth_multisampled",
+ crate::ImageClass::Sampled { multi: false, .. }
+ | crate::ImageClass::Storage { .. } => "",
+ };
+
+ let type_in_brackets = match class {
+ crate::ImageClass::Sampled { kind, .. } => {
+ // Note: The only valid widths are 4 bytes wide.
+ // The lexer has already verified this, so we can safely assume it here.
+ // https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ let element_type = kind.to_wgsl(4);
+ format!("<{element_type}>")
+ }
+ crate::ImageClass::Depth { multi: _ } => String::new(),
+ crate::ImageClass::Storage { format, access } => {
+ if access.contains(crate::StorageAccess::STORE) {
+ format!("<{},write>", format.to_wgsl())
+ } else {
+ format!("<{}>", format.to_wgsl())
+ }
+ }
+ };
+
+ format!("texture{class_suffix}{dim_suffix}{array_suffix}{type_in_brackets}")
+ }
+ Ti::Sampler { .. } => "sampler".to_string(),
+ Ti::AccelerationStructure => "acceleration_structure".to_string(),
+ Ti::RayQuery => "ray_query".to_string(),
+ Ti::BindingArray { base, size, .. } => {
+ let member_type = &types[base];
+ let base = member_type.name.as_deref().unwrap_or("unknown");
+ match size {
+ crate::ArraySize::Constant(size) => {
+ let size = constants[size].name.as_deref().unwrap_or("unknown");
+ format!("binding_array<{base}, {size}>")
+ }
+ crate::ArraySize::Dynamic => format!("binding_array<{base}>"),
+ }
+ }
+ }
+ }
+}
+
+mod type_inner_tests {
+ #[test]
+ fn to_wgsl() {
+ let mut types = crate::UniqueArena::new();
+ let mut constants = crate::Arena::new();
+ let c = constants.append(
+ crate::Constant {
+ name: Some("C".to_string()),
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(32),
+ },
+ },
+ Default::default(),
+ );
+
+ let mytype1 = types.insert(
+ crate::Type {
+ name: Some("MyType1".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![],
+ span: 0,
+ },
+ },
+ Default::default(),
+ );
+ let mytype2 = types.insert(
+ crate::Type {
+ name: Some("MyType2".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![],
+ span: 0,
+ },
+ },
+ Default::default(),
+ );
+
+ let array = crate::TypeInner::Array {
+ base: mytype1,
+ stride: 4,
+ size: crate::ArraySize::Constant(c),
+ };
+ assert_eq!(array.to_wgsl(&types, &constants), "array<MyType1, C>");
+
+ let mat = crate::TypeInner::Matrix {
+ rows: crate::VectorSize::Quad,
+ columns: crate::VectorSize::Bi,
+ width: 8,
+ };
+ assert_eq!(mat.to_wgsl(&types, &constants), "mat2x4<f64>");
+
+ let ptr = crate::TypeInner::Pointer {
+ base: mytype2,
+ space: crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ },
+ };
+ assert_eq!(ptr.to_wgsl(&types, &constants), "ptr<MyType2>");
+
+ let img1 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: true,
+ },
+ };
+ assert_eq!(
+ img1.to_wgsl(&types, &constants),
+ "texture_multisampled_2d<f32>"
+ );
+
+ let img2 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ };
+ assert_eq!(img2.to_wgsl(&types, &constants), "texture_depth_cube_array");
+
+ let img3 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: true },
+ };
+ assert_eq!(
+ img3.to_wgsl(&types, &constants),
+ "texture_depth_multisampled_2d"
+ );
+
+ let array = crate::TypeInner::BindingArray {
+ base: mytype1,
+ size: crate::ArraySize::Constant(c),
+ };
+ assert_eq!(
+ array.to_wgsl(&types, &constants),
+ "binding_array<MyType1, C>"
+ );
+ }
+}
+
+impl crate::ScalarKind {
+ /// Format a scalar kind+width as a type is written in wgsl.
+ ///
+ /// Examples: `f32`, `u64`, `bool`.
+ fn to_wgsl(self, width: u8) -> String {
+ let prefix = match self {
+ crate::ScalarKind::Sint => "i",
+ crate::ScalarKind::Uint => "u",
+ crate::ScalarKind::Float => "f",
+ crate::ScalarKind::Bool => return "bool".to_string(),
+ };
+ format!("{}{}", prefix, width * 8)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
new file mode 100644
index 0000000000..2a56ac6f80
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
@@ -0,0 +1,486 @@
+use crate::front::wgsl::parse::number::Number;
+use crate::{Arena, FastHashSet, Handle, Span};
+use std::hash::Hash;
+
+#[derive(Debug, Default)]
+pub struct TranslationUnit<'a> {
+ pub decls: Arena<GlobalDecl<'a>>,
+ /// The common expressions arena for the entire translation unit.
+ ///
+ /// All functions, global initializers, array lengths, etc. store their
+ /// expressions here. We apportion these out to individual Naga
+ /// [`Function`]s' expression arenas at lowering time. Keeping them all in a
+ /// single arena simplifies handling of things like array lengths (which are
+ /// effectively global and thus don't clearly belong to any function) and
+ /// initializers (which can appear in both function-local and module-scope
+ /// contexts).
+ ///
+ /// [`Function`]: crate::Function
+ pub expressions: Arena<Expression<'a>>,
+
+ /// Non-user-defined types, like `vec4<f32>` or `array<i32, 10>`.
+ ///
+ /// These are referred to by `Handle<ast::Type<'a>>` values.
+ /// User-defined types are referred to by name until lowering.
+ pub types: Arena<Type<'a>>,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct Ident<'a> {
+ pub name: &'a str,
+ pub span: Span,
+}
+
+#[derive(Debug)]
+pub enum IdentExpr<'a> {
+ Unresolved(&'a str),
+ Local(Handle<Local>),
+}
+
+/// A reference to a module-scope definition or predeclared object.
+///
+/// Each [`GlobalDecl`] holds a set of these values, to be resolved to
+/// specific definitions later. To support de-duplication, `Eq` and
+/// `Hash` on a `Dependency` value consider only the name, not the
+/// source location at which the reference occurs.
+#[derive(Debug)]
+pub struct Dependency<'a> {
+ /// The name referred to.
+ pub ident: &'a str,
+
+ /// The location at which the reference to that name occurs.
+ pub usage: Span,
+}
+
+impl Hash for Dependency<'_> {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.ident.hash(state);
+ }
+}
+
+impl PartialEq for Dependency<'_> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ident == other.ident
+ }
+}
+
+impl Eq for Dependency<'_> {}
+
+/// A module-scope declaration.
+#[derive(Debug)]
+pub struct GlobalDecl<'a> {
+ pub kind: GlobalDeclKind<'a>,
+
+ /// Names of all module-scope or predeclared objects this
+ /// declaration uses.
+ pub dependencies: FastHashSet<Dependency<'a>>,
+}
+
+#[derive(Debug)]
+pub enum GlobalDeclKind<'a> {
+ Fn(Function<'a>),
+ Var(GlobalVariable<'a>),
+ Const(Const<'a>),
+ Struct(Struct<'a>),
+ Type(TypeAlias<'a>),
+}
+
+#[derive(Debug)]
+pub struct FunctionArgument<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<crate::Binding>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub struct FunctionResult<'a> {
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<crate::Binding>,
+}
+
+#[derive(Debug)]
+pub struct EntryPoint {
+ pub stage: crate::ShaderStage,
+ pub early_depth_test: Option<crate::EarlyDepthTest>,
+ pub workgroup_size: [u32; 3],
+}
+
+#[cfg(doc)]
+use crate::front::wgsl::lower::{ExpressionContext, StatementContext};
+
+#[derive(Debug)]
+pub struct Function<'a> {
+ pub entry_point: Option<EntryPoint>,
+ pub name: Ident<'a>,
+ pub arguments: Vec<FunctionArgument<'a>>,
+ pub result: Option<FunctionResult<'a>>,
+
+ /// Local variable and function argument arena.
+ ///
+ /// Note that the `Local` here is actually a zero-sized type. The AST keeps
+ /// all the detailed information about locals - names, types, etc. - in
+ /// [`LocalDecl`] statements. For arguments, that information is kept in
+ /// [`arguments`]. This `Arena`'s only role is to assign a unique `Handle`
+ /// to each of them, and track their definitions' spans for use in
+ /// diagnostics.
+ ///
+ /// In the AST, when an [`Ident`] expression refers to a local variable or
+ /// argument, its [`IdentExpr`] holds the referent's `Handle<Local>` in this
+ /// arena.
+ ///
+ /// During lowering, [`LocalDecl`] statements add entries to a per-function
+ /// table that maps `Handle<Local>` values to their Naga representations,
+ /// accessed via [`StatementContext::local_table`] and
+ /// [`ExpressionContext::local_table`]. This table is then consulted when
+ /// lowering subsequent [`Ident`] expressions.
+ ///
+ /// [`LocalDecl`]: StatementKind::LocalDecl
+ /// [`arguments`]: Function::arguments
+ /// [`Ident`]: Expression::Ident
+ /// [`StatementContext::local_table`]: StatementContext::local_table
+ /// [`ExpressionContext::local_table`]: ExpressionContext::local_table
+ pub locals: Arena<Local>,
+
+ pub body: Block<'a>,
+}
+
+#[derive(Debug)]
+pub struct GlobalVariable<'a> {
+ pub name: Ident<'a>,
+ pub space: crate::AddressSpace,
+ pub binding: Option<crate::ResourceBinding>,
+ pub ty: Handle<Type<'a>>,
+ pub init: Option<Handle<Expression<'a>>>,
+}
+
+#[derive(Debug)]
+pub struct StructMember<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<crate::Binding>,
+ pub align: Option<(u32, Span)>,
+ pub size: Option<(u32, Span)>,
+}
+
+#[derive(Debug)]
+pub struct Struct<'a> {
+ pub name: Ident<'a>,
+ pub members: Vec<StructMember<'a>>,
+}
+
+#[derive(Debug)]
+pub struct TypeAlias<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+}
+
+#[derive(Debug)]
+pub struct Const<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Handle<Expression<'a>>,
+}
+
+/// The size of an [`Array`] or [`BindingArray`].
+///
+/// [`Array`]: Type::Array
+/// [`BindingArray`]: Type::BindingArray
+#[derive(Debug, Copy, Clone)]
+pub enum ArraySize<'a> {
+ /// The length as a constant expression.
+ Constant(Handle<Expression<'a>>),
+ Dynamic,
+}
+
+#[derive(Debug)]
+pub enum Type<'a> {
+ Scalar {
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+ Vector {
+ size: crate::VectorSize,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+ Atomic {
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+ Pointer {
+ base: Handle<Type<'a>>,
+ space: crate::AddressSpace,
+ },
+ Array {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+ Image {
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ },
+ Sampler {
+ comparison: bool,
+ },
+ AccelerationStructure,
+ RayQuery,
+ RayDesc,
+ RayIntersection,
+ BindingArray {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+
+ /// A user-defined type, like a struct or a type alias.
+ User(Ident<'a>),
+}
+
+#[derive(Debug, Default)]
+pub struct Block<'a> {
+ pub stmts: Vec<Statement<'a>>,
+}
+
+#[derive(Debug)]
+pub struct Statement<'a> {
+ pub kind: StatementKind<'a>,
+ pub span: Span,
+}
+
+#[derive(Debug)]
+pub enum StatementKind<'a> {
+ LocalDecl(LocalDecl<'a>),
+ Block(Block<'a>),
+ If {
+ condition: Handle<Expression<'a>>,
+ accept: Block<'a>,
+ reject: Block<'a>,
+ },
+ Switch {
+ selector: Handle<Expression<'a>>,
+ cases: Vec<SwitchCase<'a>>,
+ },
+ Loop {
+ body: Block<'a>,
+ continuing: Block<'a>,
+ break_if: Option<Handle<Expression<'a>>>,
+ },
+ Break,
+ Continue,
+ Return {
+ value: Option<Handle<Expression<'a>>>,
+ },
+ Kill,
+ Call {
+ function: Ident<'a>,
+ arguments: Vec<Handle<Expression<'a>>>,
+ },
+ Assign {
+ target: Handle<Expression<'a>>,
+ op: Option<crate::BinaryOperator>,
+ value: Handle<Expression<'a>>,
+ },
+ Increment(Handle<Expression<'a>>),
+ Decrement(Handle<Expression<'a>>),
+ Ignore(Handle<Expression<'a>>),
+}
+
+#[derive(Debug)]
+pub enum SwitchValue {
+ I32(i32),
+ U32(u32),
+ Default,
+}
+
+#[derive(Debug)]
+pub struct SwitchCase<'a> {
+ pub value: SwitchValue,
+ pub value_span: Span,
+ pub body: Block<'a>,
+ pub fall_through: bool,
+}
+
+/// A type at the head of a [`Construct`] expression.
+///
+/// WGSL has two types of [`type constructor expressions`]:
+///
+/// - Those that fully specify the type being constructed, like
+/// `vec3<f32>(x,y,z)`, which obviously constructs a `vec3<f32>`.
+///
+/// - Those that leave the component type of the composite being constructed
+/// implicit, to be inferred from the argument types, like `vec3(x,y,z)`,
+/// which constructs a `vec3<T>` where `T` is the type of `x`, `y`, and `z`.
+///
+/// This enum represents the head type of both cases. The `PartialFoo` variants
+/// represent the second case, where the component type is implicit.
+///
+/// This does not cover structs or types referred to by type aliases. See the
+/// documentation for [`Construct`] and [`Call`] expressions for details.
+///
+/// [`Construct`]: Expression::Construct
+/// [`type constructor expressions`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr
+/// [`Call`]: Expression::Call
+#[derive(Debug)]
+pub enum ConstructorType<'a> {
+ /// A scalar type or conversion: `f32(1)`.
+ Scalar {
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+
+ /// A vector construction whose component type is inferred from the
+ /// argument: `vec3(1.0)`.
+ PartialVector { size: crate::VectorSize },
+
+ /// A vector construction whose component type is written out:
+ /// `vec3<f32>(1.0)`.
+ Vector {
+ size: crate::VectorSize,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+
+ /// A matrix construction whose component type is inferred from the
+ /// argument: `mat2x2(1,2,3,4)`.
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+
+ /// A matrix construction whose component type is written out:
+ /// `mat2x2<f32>(1,2,3,4)`.
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+
+ /// An array whose component type and size are inferred from the arguments:
+ /// `array(3,4,5)`.
+ PartialArray,
+
+ /// An array whose component type and size are written out:
+ /// `array<u32, 4>(3,4,5)`.
+ Array {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+
+ /// Constructing a value of a known Naga IR type.
+ ///
+ /// This variant is produced only during lowering, when we have Naga types
+ /// available, never during parsing.
+ Type(Handle<crate::Type>),
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum Literal {
+ Bool(bool),
+ Number(Number),
+}
+
+#[cfg(doc)]
+use crate::front::wgsl::lower::Lowerer;
+
+#[derive(Debug)]
+pub enum Expression<'a> {
+ Literal(Literal),
+ Ident(IdentExpr<'a>),
+
+ /// A type constructor expression.
+ ///
+ /// This is only used for expressions like `KEYWORD(EXPR...)` and
+ /// `KEYWORD<PARAM>(EXPR...)`, where `KEYWORD` is a [type-defining keyword] like
+ /// `vec3`. These keywords cannot be shadowed by user definitions, so we can
+ /// tell that such an expression is a construction immediately.
+ ///
+ /// For ordinary identifiers, we can't tell whether an expression like
+ /// `IDENTIFIER(EXPR, ...)` is a construction expression or a function call
+ /// until we know `IDENTIFIER`'s definition, so we represent those as
+ /// [`Call`] expressions.
+ ///
+ /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords
+ /// [`Call`]: Expression::Call
+ Construct {
+ ty: ConstructorType<'a>,
+ ty_span: Span,
+ components: Vec<Handle<Expression<'a>>>,
+ },
+ Unary {
+ op: crate::UnaryOperator,
+ expr: Handle<Expression<'a>>,
+ },
+ AddrOf(Handle<Expression<'a>>),
+ Deref(Handle<Expression<'a>>),
+ Binary {
+ op: crate::BinaryOperator,
+ left: Handle<Expression<'a>>,
+ right: Handle<Expression<'a>>,
+ },
+
+ /// A function call or type constructor expression.
+ ///
+ /// We can't tell whether an expression like `IDENTIFIER(EXPR, ...)` is a
+ /// construction expression or a function call until we know `IDENTIFIER`'s
+ /// definition, so we represent everything of that form as one of these
+ /// expressions until lowering. At that point, [`Lowerer::call`] has
+ /// everything's definition in hand, and can decide whether to emit a Naga
+ /// [`Constant`], [`As`], [`Splat`], or [`Compose`] expression.
+ ///
+ /// [`Lowerer::call`]: Lowerer::call
+ /// [`Constant`]: crate::Expression::Constant
+ /// [`As`]: crate::Expression::As
+ /// [`Splat`]: crate::Expression::Splat
+ /// [`Compose`]: crate::Expression::Compose
+ Call {
+ function: Ident<'a>,
+ arguments: Vec<Handle<Expression<'a>>>,
+ },
+ Index {
+ base: Handle<Expression<'a>>,
+ index: Handle<Expression<'a>>,
+ },
+ Member {
+ base: Handle<Expression<'a>>,
+ field: Ident<'a>,
+ },
+ Bitcast {
+ expr: Handle<Expression<'a>>,
+ to: Handle<Type<'a>>,
+ ty_span: Span,
+ },
+}
+
+#[derive(Debug)]
+pub struct LocalVariable<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Option<Handle<Expression<'a>>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub struct Let<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Handle<Expression<'a>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub enum LocalDecl<'a> {
+ Var(LocalVariable<'a>),
+ Let(Let<'a>),
+}
+
+#[derive(Debug)]
+/// A placeholder for a local variable declaration.
+///
+/// See [`Function::locals`] for more information.
+pub struct Local;
diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
new file mode 100644
index 0000000000..a69dbbe470
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
@@ -0,0 +1,236 @@
+use super::Error;
+use crate::Span;
+
+pub fn map_address_space(word: &str, span: Span) -> Result<crate::AddressSpace, Error<'_>> {
+ match word {
+ "private" => Ok(crate::AddressSpace::Private),
+ "workgroup" => Ok(crate::AddressSpace::WorkGroup),
+ "uniform" => Ok(crate::AddressSpace::Uniform),
+ "storage" => Ok(crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ }),
+ "push_constant" => Ok(crate::AddressSpace::PushConstant),
+ "function" => Ok(crate::AddressSpace::Function),
+ _ => Err(Error::UnknownAddressSpace(span)),
+ }
+}
+
+pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> {
+ Ok(match word {
+ "position" => crate::BuiltIn::Position { invariant: false },
+ // vertex
+ "vertex_index" => crate::BuiltIn::VertexIndex,
+ "instance_index" => crate::BuiltIn::InstanceIndex,
+ "view_index" => crate::BuiltIn::ViewIndex,
+ // fragment
+ "front_facing" => crate::BuiltIn::FrontFacing,
+ "frag_depth" => crate::BuiltIn::FragDepth,
+ "primitive_index" => crate::BuiltIn::PrimitiveIndex,
+ "sample_index" => crate::BuiltIn::SampleIndex,
+ "sample_mask" => crate::BuiltIn::SampleMask,
+ // compute
+ "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
+ "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
+ "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
+ "workgroup_id" => crate::BuiltIn::WorkGroupId,
+ "num_workgroups" => crate::BuiltIn::NumWorkGroups,
+ _ => return Err(Error::UnknownBuiltin(span)),
+ })
+}
+
+pub fn map_interpolation(word: &str, span: Span) -> Result<crate::Interpolation, Error<'_>> {
+ match word {
+ "linear" => Ok(crate::Interpolation::Linear),
+ "flat" => Ok(crate::Interpolation::Flat),
+ "perspective" => Ok(crate::Interpolation::Perspective),
+ _ => Err(Error::UnknownAttribute(span)),
+ }
+}
+
+pub fn map_sampling(word: &str, span: Span) -> Result<crate::Sampling, Error<'_>> {
+ match word {
+ "center" => Ok(crate::Sampling::Center),
+ "centroid" => Ok(crate::Sampling::Centroid),
+ "sample" => Ok(crate::Sampling::Sample),
+ _ => Err(Error::UnknownAttribute(span)),
+ }
+}
+
+pub fn map_storage_format(word: &str, span: Span) -> Result<crate::StorageFormat, Error<'_>> {
+ use crate::StorageFormat as Sf;
+ Ok(match word {
+ "r8unorm" => Sf::R8Unorm,
+ "r8snorm" => Sf::R8Snorm,
+ "r8uint" => Sf::R8Uint,
+ "r8sint" => Sf::R8Sint,
+ "r16unorm" => Sf::R16Unorm,
+ "r16snorm" => Sf::R16Snorm,
+ "r16uint" => Sf::R16Uint,
+ "r16sint" => Sf::R16Sint,
+ "r16float" => Sf::R16Float,
+ "rg8unorm" => Sf::Rg8Unorm,
+ "rg8snorm" => Sf::Rg8Snorm,
+ "rg8uint" => Sf::Rg8Uint,
+ "rg8sint" => Sf::Rg8Sint,
+ "r32uint" => Sf::R32Uint,
+ "r32sint" => Sf::R32Sint,
+ "r32float" => Sf::R32Float,
+ "rg16unorm" => Sf::Rg16Unorm,
+ "rg16snorm" => Sf::Rg16Snorm,
+ "rg16uint" => Sf::Rg16Uint,
+ "rg16sint" => Sf::Rg16Sint,
+ "rg16float" => Sf::Rg16Float,
+ "rgba8unorm" => Sf::Rgba8Unorm,
+ "rgba8snorm" => Sf::Rgba8Snorm,
+ "rgba8uint" => Sf::Rgba8Uint,
+ "rgba8sint" => Sf::Rgba8Sint,
+ "rgb10a2unorm" => Sf::Rgb10a2Unorm,
+ "rg11b10float" => Sf::Rg11b10Float,
+ "rg32uint" => Sf::Rg32Uint,
+ "rg32sint" => Sf::Rg32Sint,
+ "rg32float" => Sf::Rg32Float,
+ "rgba16unorm" => Sf::Rgba16Unorm,
+ "rgba16snorm" => Sf::Rgba16Snorm,
+ "rgba16uint" => Sf::Rgba16Uint,
+ "rgba16sint" => Sf::Rgba16Sint,
+ "rgba16float" => Sf::Rgba16Float,
+ "rgba32uint" => Sf::Rgba32Uint,
+ "rgba32sint" => Sf::Rgba32Sint,
+ "rgba32float" => Sf::Rgba32Float,
+ _ => return Err(Error::UnknownStorageFormat(span)),
+ })
+}
+
+pub fn get_scalar_type(word: &str) -> Option<(crate::ScalarKind, crate::Bytes)> {
+ match word {
+ // "f16" => Some((crate::ScalarKind::Float, 2)),
+ "f32" => Some((crate::ScalarKind::Float, 4)),
+ "f64" => Some((crate::ScalarKind::Float, 8)),
+ "i32" => Some((crate::ScalarKind::Sint, 4)),
+ "u32" => Some((crate::ScalarKind::Uint, 4)),
+ "bool" => Some((crate::ScalarKind::Bool, crate::BOOL_WIDTH)),
+ _ => None,
+ }
+}
+
+pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ match word {
+ "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
+ "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
+ "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
+ "dpdxFine" => Some((Axis::X, Ctrl::Fine)),
+ "dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
+ "fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
+ "dpdx" => Some((Axis::X, Ctrl::None)),
+ "dpdy" => Some((Axis::Y, Ctrl::None)),
+ "fwidth" => Some((Axis::Width, Ctrl::None)),
+ _ => None,
+ }
+}
+
+pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
+ match word {
+ "any" => Some(crate::RelationalFunction::Any),
+ "all" => Some(crate::RelationalFunction::All),
+ _ => None,
+ }
+}
+
+pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
+ use crate::MathFunction as Mf;
+ Some(match word {
+ // comparison
+ "abs" => Mf::Abs,
+ "min" => Mf::Min,
+ "max" => Mf::Max,
+ "clamp" => Mf::Clamp,
+ "saturate" => Mf::Saturate,
+ // trigonometry
+ "cos" => Mf::Cos,
+ "cosh" => Mf::Cosh,
+ "sin" => Mf::Sin,
+ "sinh" => Mf::Sinh,
+ "tan" => Mf::Tan,
+ "tanh" => Mf::Tanh,
+ "acos" => Mf::Acos,
+ "acosh" => Mf::Acosh,
+ "asin" => Mf::Asin,
+ "asinh" => Mf::Asinh,
+ "atan" => Mf::Atan,
+ "atanh" => Mf::Atanh,
+ "atan2" => Mf::Atan2,
+ "radians" => Mf::Radians,
+ "degrees" => Mf::Degrees,
+ // decomposition
+ "ceil" => Mf::Ceil,
+ "floor" => Mf::Floor,
+ "round" => Mf::Round,
+ "fract" => Mf::Fract,
+ "trunc" => Mf::Trunc,
+ "modf" => Mf::Modf,
+ "frexp" => Mf::Frexp,
+ "ldexp" => Mf::Ldexp,
+ // exponent
+ "exp" => Mf::Exp,
+ "exp2" => Mf::Exp2,
+ "log" => Mf::Log,
+ "log2" => Mf::Log2,
+ "pow" => Mf::Pow,
+ // geometry
+ "dot" => Mf::Dot,
+ "outerProduct" => Mf::Outer,
+ "cross" => Mf::Cross,
+ "distance" => Mf::Distance,
+ "length" => Mf::Length,
+ "normalize" => Mf::Normalize,
+ "faceForward" => Mf::FaceForward,
+ "reflect" => Mf::Reflect,
+ "refract" => Mf::Refract,
+ // computational
+ "sign" => Mf::Sign,
+ "fma" => Mf::Fma,
+ "mix" => Mf::Mix,
+ "step" => Mf::Step,
+ "smoothstep" => Mf::SmoothStep,
+ "sqrt" => Mf::Sqrt,
+ "inverseSqrt" => Mf::InverseSqrt,
+ "transpose" => Mf::Transpose,
+ "determinant" => Mf::Determinant,
+ // bits
+ "countTrailingZeros" => Mf::CountTrailingZeros,
+ "countLeadingZeros" => Mf::CountLeadingZeros,
+ "countOneBits" => Mf::CountOneBits,
+ "reverseBits" => Mf::ReverseBits,
+ "extractBits" => Mf::ExtractBits,
+ "insertBits" => Mf::InsertBits,
+ "firstTrailingBit" => Mf::FindLsb,
+ "firstLeadingBit" => Mf::FindMsb,
+ // data packing
+ "pack4x8snorm" => Mf::Pack4x8snorm,
+ "pack4x8unorm" => Mf::Pack4x8unorm,
+ "pack2x16snorm" => Mf::Pack2x16snorm,
+ "pack2x16unorm" => Mf::Pack2x16unorm,
+ "pack2x16float" => Mf::Pack2x16float,
+ // data unpacking
+ "unpack4x8snorm" => Mf::Unpack4x8snorm,
+ "unpack4x8unorm" => Mf::Unpack4x8unorm,
+ "unpack2x16snorm" => Mf::Unpack2x16snorm,
+ "unpack2x16unorm" => Mf::Unpack2x16unorm,
+ "unpack2x16float" => Mf::Unpack2x16float,
+ _ => return None,
+ })
+}
+
+pub fn map_conservative_depth(
+ word: &str,
+ span: Span,
+) -> Result<crate::ConservativeDepth, Error<'_>> {
+ use crate::ConservativeDepth as Cd;
+ match word {
+ "greater_equal" => Ok(Cd::GreaterEqual),
+ "less_equal" => Ok(Cd::LessEqual),
+ "unchanged" => Ok(Cd::Unchanged),
+ _ => Err(Error::UnknownConservativeDepth(span)),
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/lexer.rs b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs
new file mode 100644
index 0000000000..4654d4087f
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs
@@ -0,0 +1,723 @@
+use super::{number::consume_number, Error, ExpectedToken};
+use crate::front::wgsl::error::NumberError;
+use crate::front::wgsl::parse::{conv, Number};
+use crate::Span;
+
+type TokenSpan<'a> = (Token<'a>, Span);
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Token<'a> {
+ Separator(char),
+ Paren(char),
+ Attribute,
+ Number(Result<Number, NumberError>),
+ Word(&'a str),
+ Operation(char),
+ LogicalOperation(char),
+ ShiftOperation(char),
+ AssignmentOperation(char),
+ IncrementOperation,
+ DecrementOperation,
+ Arrow,
+ Unknown(char),
+ Trivia,
+ End,
+}
+
+fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) {
+ let pos = input.find(|c| !what(c)).unwrap_or(input.len());
+ input.split_at(pos)
+}
+
+/// Return the token at the start of `input`.
+///
+/// If `generic` is `false`, then the bit shift operators `>>` or `<<`
+/// are valid lookahead tokens for the current parser state (see [§3.1
+/// Parsing] in the WGSL specification). In other words:
+///
+/// - If `generic` is `true`, then we are expecting an angle bracket
+/// around a generic type parameter, like the `<` and `>` in
+/// `vec3<f32>`, so interpret `<` and `>` as `Token::Paren` tokens,
+/// even if they're part of `<<` or `>>` sequences.
+///
+/// - Otherwise, interpret `<<` and `>>` as shift operators:
+/// `Token::LogicalOperation` tokens.
+///
+/// [§3.1 Parsing]: https://gpuweb.github.io/gpuweb/wgsl/#parsing
+fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) {
+ let mut chars = input.chars();
+ let cur = match chars.next() {
+ Some(c) => c,
+ None => return (Token::End, ""),
+ };
+ match cur {
+ ':' | ';' | ',' => (Token::Separator(cur), chars.as_str()),
+ '.' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('0'..='9') => consume_number(input),
+ _ => (Token::Separator(cur), og_chars),
+ }
+ }
+ '@' => (Token::Attribute, chars.as_str()),
+ '(' | ')' | '{' | '}' | '[' | ']' => (Token::Paren(cur), chars.as_str()),
+ '<' | '>' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') if !generic => (Token::LogicalOperation(cur), chars.as_str()),
+ Some(c) if c == cur && !generic => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::ShiftOperation(cur), og_chars),
+ }
+ }
+ _ => (Token::Paren(cur), og_chars),
+ }
+ }
+ '0'..='9' => consume_number(input),
+ '/' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('/') => {
+ let _ = chars.position(is_comment_end);
+ (Token::Trivia, chars.as_str())
+ }
+ Some('*') => {
+ let mut depth = 1;
+ let mut prev = None;
+
+ for c in &mut chars {
+ match (prev, c) {
+ (Some('*'), '/') => {
+ prev = None;
+ depth -= 1;
+ if depth == 0 {
+ return (Token::Trivia, chars.as_str());
+ }
+ }
+ (Some('/'), '*') => {
+ prev = None;
+ depth += 1;
+ }
+ _ => {
+ prev = Some(c);
+ }
+ }
+ }
+
+ (Token::End, "")
+ }
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '-' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('>') => (Token::Arrow, chars.as_str()),
+ Some('0'..='9' | '.') => consume_number(input),
+ Some('-') => (Token::DecrementOperation, chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '+' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('+') => (Token::IncrementOperation, chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '*' | '%' | '^' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '~' => (Token::Operation(cur), chars.as_str()),
+ '=' | '!' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::LogicalOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '&' | '|' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some(c) if c == cur => (Token::LogicalOperation(cur), chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ _ if is_blankspace(cur) => {
+ let (_, rest) = consume_any(input, is_blankspace);
+ (Token::Trivia, rest)
+ }
+ _ if is_word_start(cur) => {
+ let (word, rest) = consume_any(input, is_word_part);
+ (Token::Word(word), rest)
+ }
+ _ => (Token::Unknown(cur), chars.as_str()),
+ }
+}
+
+/// Returns whether or not a char is a comment end
+/// (Unicode Pattern_White_Space excluding U+0020, U+0009, U+200E and U+200F)
+const fn is_comment_end(c: char) -> bool {
+ match c {
+ '\u{000a}'..='\u{000d}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true,
+ _ => false,
+ }
+}
+
+/// Returns whether or not a char is a blankspace (Unicode Pattern_White_Space)
+const fn is_blankspace(c: char) -> bool {
+ match c {
+ '\u{0020}'
+ | '\u{0009}'..='\u{000d}'
+ | '\u{0085}'
+ | '\u{200e}'
+ | '\u{200f}'
+ | '\u{2028}'
+ | '\u{2029}' => true,
+ _ => false,
+ }
+}
+
+/// Returns whether or not a char is a word start (Unicode XID_Start + '_')
+fn is_word_start(c: char) -> bool {
+ c == '_' || unicode_xid::UnicodeXID::is_xid_start(c)
+}
+
+/// Returns whether or not a char is a word part (Unicode XID_Continue)
+fn is_word_part(c: char) -> bool {
+ unicode_xid::UnicodeXID::is_xid_continue(c)
+}
+
+#[derive(Clone)]
+pub(in crate::front::wgsl) struct Lexer<'a> {
+ input: &'a str,
+ pub(in crate::front::wgsl) source: &'a str,
+ // The byte offset of the end of the last non-trivia token.
+ last_end_offset: usize,
+}
+
+impl<'a> Lexer<'a> {
+ pub(in crate::front::wgsl) const fn new(input: &'a str) -> Self {
+ Lexer {
+ input,
+ source: input,
+ last_end_offset: 0,
+ }
+ }
+
+ /// Calls the function with a lexer and returns the result of the function as well as the span for everything the function parsed
+ ///
+ /// # Examples
+ /// ```ignore
+ /// let lexer = Lexer::new("5");
+ /// let (value, span) = lexer.capture_span(Lexer::next_uint_literal);
+ /// assert_eq!(value, 5);
+ /// ```
+ #[inline]
+ pub fn capture_span<T, E>(
+ &mut self,
+ inner: impl FnOnce(&mut Self) -> Result<T, E>,
+ ) -> Result<(T, Span), E> {
+ let start = self.current_byte_offset();
+ let res = inner(self)?;
+ let end = self.current_byte_offset();
+ Ok((res, Span::from(start..end)))
+ }
+
+ pub(in crate::front::wgsl) fn start_byte_offset(&mut self) -> usize {
+ loop {
+ // Eat all trivia because `next` doesn't eat trailing trivia.
+ let (token, rest) = consume_token(self.input, false);
+ if let Token::Trivia = token {
+ self.input = rest;
+ } else {
+ return self.current_byte_offset();
+ }
+ }
+ }
+
+ fn peek_token_and_rest(&mut self) -> (TokenSpan<'a>, &'a str) {
+ let mut cloned = self.clone();
+ let token = cloned.next();
+ let rest = cloned.input;
+ (token, rest)
+ }
+
+ const fn current_byte_offset(&self) -> usize {
+ self.source.len() - self.input.len()
+ }
+
+ pub(in crate::front::wgsl) fn span_from(&self, offset: usize) -> Span {
+ Span::from(offset..self.last_end_offset)
+ }
+
+ /// Return the next non-whitespace token from `self`.
+ ///
+ /// Assume we are a parse state where bit shift operators may
+ /// occur, but not angle brackets.
+ #[must_use]
+ pub(in crate::front::wgsl) fn next(&mut self) -> TokenSpan<'a> {
+ self.next_impl(false)
+ }
+
+ /// Return the next non-whitespace token from `self`.
+ ///
+ /// Assume we are in a parse state where angle brackets may occur,
+ /// but not bit shift operators.
+ #[must_use]
+ pub(in crate::front::wgsl) fn next_generic(&mut self) -> TokenSpan<'a> {
+ self.next_impl(true)
+ }
+
+ /// Return the next non-whitespace token from `self`, with a span.
+ ///
+ /// See [`consume_token`] for the meaning of `generic`.
+ fn next_impl(&mut self, generic: bool) -> TokenSpan<'a> {
+ let mut start_byte_offset = self.current_byte_offset();
+ loop {
+ let (token, rest) = consume_token(self.input, generic);
+ self.input = rest;
+ match token {
+ Token::Trivia => start_byte_offset = self.current_byte_offset(),
+ _ => {
+ self.last_end_offset = self.current_byte_offset();
+ return (token, self.span_from(start_byte_offset));
+ }
+ }
+ }
+ }
+
+ #[must_use]
+ pub(in crate::front::wgsl) fn peek(&mut self) -> TokenSpan<'a> {
+ let (token, _) = self.peek_token_and_rest();
+ token
+ }
+
+ pub(in crate::front::wgsl) fn expect_span(
+ &mut self,
+ expected: Token<'a>,
+ ) -> Result<Span, Error<'a>> {
+ let next = self.next();
+ if next.0 == expected {
+ Ok(next.1)
+ } else {
+ Err(Error::Unexpected(next.1, ExpectedToken::Token(expected)))
+ }
+ }
+
+ pub(in crate::front::wgsl) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> {
+ self.expect_span(expected)?;
+ Ok(())
+ }
+
+ pub(in crate::front::wgsl) fn expect_generic_paren(
+ &mut self,
+ expected: char,
+ ) -> Result<(), Error<'a>> {
+ let next = self.next_generic();
+ if next.0 == Token::Paren(expected) {
+ Ok(())
+ } else {
+ Err(Error::Unexpected(
+ next.1,
+ ExpectedToken::Token(Token::Paren(expected)),
+ ))
+ }
+ }
+
+ /// If the next token matches it is skipped and true is returned
+ pub(in crate::front::wgsl) fn skip(&mut self, what: Token<'_>) -> bool {
+ let (peeked_token, rest) = self.peek_token_and_rest();
+ if peeked_token.0 == what {
+ self.input = rest;
+ true
+ } else {
+ false
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_ident_with_span(
+ &mut self,
+ ) -> Result<(&'a str, Span), Error<'a>> {
+ match self.next() {
+ (Token::Word(word), span) if word == "_" => {
+ Err(Error::InvalidIdentifierUnderscore(span))
+ }
+ (Token::Word(word), span) if word.starts_with("__") => {
+ Err(Error::ReservedIdentifierPrefix(span))
+ }
+ (Token::Word(word), span) => Ok((word, span)),
+ other => Err(Error::Unexpected(other.1, ExpectedToken::Identifier)),
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_ident(
+ &mut self,
+ ) -> Result<super::ast::Ident<'a>, Error<'a>> {
+ let ident = self
+ .next_ident_with_span()
+ .map(|(name, span)| super::ast::Ident { name, span })?;
+
+ if crate::keywords::wgsl::RESERVED.contains(&ident.name) {
+ return Err(Error::ReservedKeyword(ident.span));
+ }
+
+ Ok(ident)
+ }
+
+ /// Parses a generic scalar type, for example `<f32>`.
+ pub(in crate::front::wgsl) fn next_scalar_generic(
+ &mut self,
+ ) -> Result<(crate::ScalarKind, crate::Bytes), Error<'a>> {
+ self.expect_generic_paren('<')?;
+ let pair = match self.next() {
+ (Token::Word(word), span) => {
+ conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span))
+ }
+ (_, span) => Err(Error::UnknownScalarType(span)),
+ }?;
+ self.expect_generic_paren('>')?;
+ Ok(pair)
+ }
+
+ /// Parses a generic scalar type, for example `<f32>`.
+ ///
+ /// Returns the span covering the inner type, excluding the brackets.
+ pub(in crate::front::wgsl) fn next_scalar_generic_with_span(
+ &mut self,
+ ) -> Result<(crate::ScalarKind, crate::Bytes, Span), Error<'a>> {
+ self.expect_generic_paren('<')?;
+ let pair = match self.next() {
+ (Token::Word(word), span) => conv::get_scalar_type(word)
+ .map(|(a, b)| (a, b, span))
+ .ok_or(Error::UnknownScalarType(span)),
+ (_, span) => Err(Error::UnknownScalarType(span)),
+ }?;
+ self.expect_generic_paren('>')?;
+ Ok(pair)
+ }
+
+ pub(in crate::front::wgsl) fn next_storage_access(
+ &mut self,
+ ) -> Result<crate::StorageAccess, Error<'a>> {
+ let (ident, span) = self.next_ident_with_span()?;
+ match ident {
+ "read" => Ok(crate::StorageAccess::LOAD),
+ "write" => Ok(crate::StorageAccess::STORE),
+ "read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE),
+ _ => Err(Error::UnknownAccess(span)),
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_format_generic(
+ &mut self,
+ ) -> Result<(crate::StorageFormat, crate::StorageAccess), Error<'a>> {
+ self.expect(Token::Paren('<'))?;
+ let (ident, ident_span) = self.next_ident_with_span()?;
+ let format = conv::map_storage_format(ident, ident_span)?;
+ self.expect(Token::Separator(','))?;
+ let access = self.next_storage_access()?;
+ self.expect(Token::Paren('>'))?;
+ Ok((format, access))
+ }
+
+ pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<(), Error<'a>> {
+ self.expect(Token::Paren('('))
+ }
+
+ pub(in crate::front::wgsl) fn close_arguments(&mut self) -> Result<(), Error<'a>> {
+ let _ = self.skip(Token::Separator(','));
+ self.expect(Token::Paren(')'))
+ }
+
+ pub(in crate::front::wgsl) fn next_argument(&mut self) -> Result<bool, Error<'a>> {
+ let paren = Token::Paren(')');
+ if self.skip(Token::Separator(',')) {
+ Ok(!self.skip(paren))
+ } else {
+ self.expect(paren).map(|()| false)
+ }
+ }
+}
+
+#[cfg(test)]
+fn sub_test(source: &str, expected_tokens: &[Token]) {
+ let mut lex = Lexer::new(source);
+ for &token in expected_tokens {
+ assert_eq!(lex.next().0, token);
+ }
+ assert_eq!(lex.next().0, Token::End);
+}
+
+#[test]
+fn test_numbers() {
+ // WGSL spec examples //
+
+ // decimal integer
+ sub_test(
+ "0x123 0X123u 1u 123 0 0i 0x3f",
+ &[
+ Token::Number(Ok(Number::I32(291))),
+ Token::Number(Ok(Number::U32(291))),
+ Token::Number(Ok(Number::U32(1))),
+ Token::Number(Ok(Number::I32(123))),
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::I32(63))),
+ ],
+ );
+ // decimal floating point
+ sub_test(
+ "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h",
+ &[
+ Token::Number(Ok(Number::F32(0.))),
+ Token::Number(Ok(Number::F32(1.))),
+ Token::Number(Ok(Number::F32(0.01))),
+ Token::Number(Ok(Number::F32(12.34))),
+ Token::Number(Ok(Number::F32(0.))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ Token::Number(Ok(Number::F32(0.001))),
+ Token::Number(Ok(Number::F32(43.75))),
+ Token::Number(Ok(Number::F32(16.))),
+ Token::Number(Ok(Number::F32(0.1875))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ Token::Number(Ok(Number::F32(0.12109375))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ ],
+ );
+
+ // MIN / MAX //
+
+ // min / max decimal signed integer
+ sub_test(
+ "-2147483648i 2147483647i -2147483649i 2147483648i",
+ &[
+ Token::Number(Ok(Number::I32(i32::MIN))),
+ Token::Number(Ok(Number::I32(i32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+ // min / max decimal unsigned integer
+ sub_test(
+ "0u 4294967295u -1u 4294967296u",
+ &[
+ Token::Number(Ok(Number::U32(u32::MIN))),
+ Token::Number(Ok(Number::U32(u32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ // min / max hexadecimal signed integer
+ sub_test(
+ "-0x80000000i 0x7FFFFFFFi -0x80000001i 0x80000000i",
+ &[
+ Token::Number(Ok(Number::I32(i32::MIN))),
+ Token::Number(Ok(Number::I32(i32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+ // min / max hexadecimal unsigned integer
+ sub_test(
+ "0x0u 0xFFFFFFFFu -0x1u 0x100000000u",
+ &[
+ Token::Number(Ok(Number::U32(u32::MIN))),
+ Token::Number(Ok(Number::U32(u32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ /// ≈ 2^-126 * 2^−23 (= 2^−149)
+ const SMALLEST_POSITIVE_SUBNORMAL_F32: f32 = 1e-45;
+ /// ≈ 2^-126 * (1 − 2^−23)
+ const LARGEST_SUBNORMAL_F32: f32 = 1.1754942e-38;
+ /// ≈ 2^-126
+ const SMALLEST_POSITIVE_NORMAL_F32: f32 = f32::MIN_POSITIVE;
+ /// ≈ 1 − 2^−24
+ const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994;
+ /// ≈ 1 + 2^−23
+ const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001;
+ /// ≈ -(2^127 * (2 − 2^−23))
+ const SMALLEST_NORMAL_F32: f32 = f32::MIN;
+ /// ≈ 2^127 * (2 − 2^−23)
+ const LARGEST_NORMAL_F32: f32 = f32::MAX;
+
+ // decimal floating point
+ sub_test(
+ "1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f -3.40282347e+38f 3.40282347e+38f",
+ &[
+ Token::Number(Ok(Number::F32(
+ SMALLEST_POSITIVE_SUBNORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_SUBNORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_POSITIVE_NORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_F32_LESS_THAN_ONE,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_F32_LARGER_THAN_ONE,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_NORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_NORMAL_F32,
+ ))),
+ ],
+ );
+ sub_test(
+ "-3.40282367e+38f 3.40282367e+38f",
+ &[
+ Token::Number(Err(NumberError::NotRepresentable)), // ≈ -2^128
+ Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128
+ ],
+ );
+
+ // hexadecimal floating point
+ sub_test(
+ "0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f -0xFFFFFFp+104f 0xFFFFFFp+104f",
+ &[
+ Token::Number(Ok(Number::F32(
+ SMALLEST_POSITIVE_SUBNORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_SUBNORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_POSITIVE_NORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_F32_LESS_THAN_ONE,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_F32_LARGER_THAN_ONE,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_NORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_NORMAL_F32,
+ ))),
+ ],
+ );
+ sub_test(
+ "-0x1p128f 0x1p128f 0x1.000001p0f",
+ &[
+ Token::Number(Err(NumberError::NotRepresentable)), // = -2^128
+ Token::Number(Err(NumberError::NotRepresentable)), // = 2^128
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+}
+
+#[test]
+fn test_tokens() {
+ sub_test("id123_OK", &[Token::Word("id123_OK")]);
+ sub_test(
+ "92No",
+ &[Token::Number(Ok(Number::I32(92))), Token::Word("No")],
+ );
+ sub_test(
+ "2u3o",
+ &[
+ Token::Number(Ok(Number::U32(2))),
+ Token::Number(Ok(Number::I32(3))),
+ Token::Word("o"),
+ ],
+ );
+ sub_test(
+ "2.4f44po",
+ &[
+ Token::Number(Ok(Number::F32(2.4))),
+ Token::Number(Ok(Number::I32(44))),
+ Token::Word("po"),
+ ],
+ );
+ sub_test(
+ "Δέλτα réflexion Кызыл 𐰓𐰏𐰇 朝焼け سلام 검정 שָׁלוֹם गुलाबी փիրուզ",
+ &[
+ Token::Word("Δέλτα"),
+ Token::Word("réflexion"),
+ Token::Word("Кызыл"),
+ Token::Word("𐰓𐰏𐰇"),
+ Token::Word("朝焼け"),
+ Token::Word("سلام"),
+ Token::Word("검정"),
+ Token::Word("שָׁלוֹם"),
+ Token::Word("गुलाबी"),
+ Token::Word("փիրուզ"),
+ ],
+ );
+ sub_test("æNoø", &[Token::Word("æNoø")]);
+ sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]);
+ sub_test("No好", &[Token::Word("No好")]);
+ sub_test("_No", &[Token::Word("_No")]);
+ sub_test(
+ "*/*/***/*//=/*****//",
+ &[
+ Token::Operation('*'),
+ Token::AssignmentOperation('/'),
+ Token::Operation('/'),
+ ],
+ );
+}
+
+#[test]
+fn test_variable_decl() {
+ sub_test(
+ "@group(0 ) var< uniform> texture: texture_multisampled_2d <f32 >;",
+ &[
+ Token::Attribute,
+ Token::Word("group"),
+ Token::Paren('('),
+ Token::Number(Ok(Number::I32(0))),
+ Token::Paren(')'),
+ Token::Word("var"),
+ Token::Paren('<'),
+ Token::Word("uniform"),
+ Token::Paren('>'),
+ Token::Word("texture"),
+ Token::Separator(':'),
+ Token::Word("texture_multisampled_2d"),
+ Token::Paren('<'),
+ Token::Word("f32"),
+ Token::Paren('>'),
+ Token::Separator(';'),
+ ],
+ );
+ sub_test(
+ "var<storage,read_write> buffer: array<u32>;",
+ &[
+ Token::Word("var"),
+ Token::Paren('<'),
+ Token::Word("storage"),
+ Token::Separator(','),
+ Token::Word("read_write"),
+ Token::Paren('>'),
+ Token::Word("buffer"),
+ Token::Separator(':'),
+ Token::Word("array"),
+ Token::Paren('<'),
+ Token::Word("u32"),
+ Token::Paren('>'),
+ Token::Separator(';'),
+ ],
+ );
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
new file mode 100644
index 0000000000..cf9892a1f3
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
@@ -0,0 +1,2298 @@
+use crate::front::wgsl::error::{Error, ExpectedToken};
+use crate::front::wgsl::parse::lexer::{Lexer, Token};
+use crate::front::wgsl::parse::number::Number;
+use crate::front::SymbolTable;
+use crate::{Arena, FastHashSet, Handle, Span};
+
+pub mod ast;
+pub mod conv;
+pub mod lexer;
+pub mod number;
+
+/// State for constructing an AST expression.
+///
+/// Not to be confused with `lower::ExpressionContext`.
+struct ExpressionContext<'input, 'temp, 'out> {
+ /// The [`TranslationUnit::expressions`] arena to which we should contribute
+ /// expressions.
+ ///
+ /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions
+ expressions: &'out mut Arena<ast::Expression<'input>>,
+
+ /// The [`TranslationUnit::types`] arena to which we should contribute new
+ /// types.
+ ///
+ /// [`TranslationUnit::types`]: ast::TranslationUnit::types
+ types: &'out mut Arena<ast::Type<'input>>,
+
+ /// A map from identifiers in scope to the locals/arguments they represent.
+ ///
+ /// The handles refer to the [`Function::locals`] area; see that field's
+ /// documentation for details.
+ ///
+ /// [`Function::locals`]: ast::Function::locals
+ local_table: &'temp mut SymbolTable<&'input str, Handle<ast::Local>>,
+
+ /// The [`Function::locals`] arena for the function we're building.
+ ///
+ /// [`Function::locals`]: ast::Function::locals
+ locals: &'out mut Arena<ast::Local>,
+
+ /// Identifiers used by the current global declaration that have no local definition.
+ ///
+ /// This becomes the [`GlobalDecl`]'s [`dependencies`] set.
+ ///
+ /// Note that we don't know at parse time what kind of [`GlobalDecl`] the
+ /// name refers to. We can't look up names until we've seen the entire
+ /// translation unit.
+ ///
+ /// [`GlobalDecl`]: ast::GlobalDecl
+ /// [`dependencies`]: ast::GlobalDecl::dependencies
+ unresolved: &'out mut FastHashSet<ast::Dependency<'input>>,
+}
+
+impl<'a> ExpressionContext<'a, '_, '_> {
+ fn reborrow(&mut self) -> ExpressionContext<'a, '_, '_> {
+ ExpressionContext {
+ expressions: self.expressions,
+ types: self.types,
+ local_table: self.local_table,
+ locals: self.locals,
+ unresolved: self.unresolved,
+ }
+ }
+
+ fn parse_binary_op(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>,
+ mut parser: impl FnMut(
+ &mut Lexer<'a>,
+ ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let start = lexer.start_byte_offset();
+ let mut accumulator = parser(lexer, self.reborrow())?;
+ while let Some(op) = classifier(lexer.peek().0) {
+ let _ = lexer.next();
+ let left = accumulator;
+ let right = parser(lexer, self.reborrow())?;
+ accumulator = self.expressions.append(
+ ast::Expression::Binary { op, left, right },
+ lexer.span_from(start),
+ );
+ }
+ Ok(accumulator)
+ }
+}
+
+/// Which grammar rule we are in the midst of parsing.
+///
+/// This is used for error checking. `Parser` maintains a stack of
+/// these and (occasionally) checks that it is being pushed and popped
+/// as expected.
+#[derive(Clone, Debug, PartialEq)]
+enum Rule {
+ Attribute,
+ VariableDecl,
+ TypeDecl,
+ FunctionDecl,
+ Block,
+ Statement,
+ PrimaryExpr,
+ SingularExpr,
+ UnaryExpr,
+ GeneralExpr,
+}
+
+#[derive(Default)]
+struct BindingParser {
+ location: Option<u32>,
+ built_in: Option<crate::BuiltIn>,
+ interpolation: Option<crate::Interpolation>,
+ sampling: Option<crate::Sampling>,
+ invariant: bool,
+}
+
+impl BindingParser {
+ fn parse<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ name: &'a str,
+ name_span: Span,
+ ) -> Result<(), Error<'a>> {
+ match name {
+ "location" => {
+ lexer.expect(Token::Paren('('))?;
+ self.location = Some(Parser::non_negative_i32_literal(lexer)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "builtin" => {
+ lexer.expect(Token::Paren('('))?;
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.built_in = Some(conv::map_built_in(raw, span)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "interpolate" => {
+ lexer.expect(Token::Paren('('))?;
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.interpolation = Some(conv::map_interpolation(raw, span)?);
+ if lexer.skip(Token::Separator(',')) {
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.sampling = Some(conv::map_sampling(raw, span)?);
+ }
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "invariant" => self.invariant = true,
+ _ => return Err(Error::UnknownAttribute(name_span)),
+ }
+ Ok(())
+ }
+
+ const fn finish<'a>(self, span: Span) -> Result<Option<crate::Binding>, Error<'a>> {
+ match (
+ self.location,
+ self.built_in,
+ self.interpolation,
+ self.sampling,
+ self.invariant,
+ ) {
+ (None, None, None, None, false) => Ok(None),
+ (Some(location), None, interpolation, sampling, false) => {
+ // Before handing over the completed `Module`, we call
+ // `apply_default_interpolation` to ensure that the interpolation and
+ // sampling have been explicitly specified on all vertex shader output and fragment
+ // shader input user bindings, so leaving them potentially `None` here is fine.
+ Ok(Some(crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ }))
+ }
+ (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant) => {
+ Ok(Some(crate::Binding::BuiltIn(crate::BuiltIn::Position {
+ invariant,
+ })))
+ }
+ (None, Some(built_in), None, None, false) => {
+ Ok(Some(crate::Binding::BuiltIn(built_in)))
+ }
+ (_, _, _, _, _) => Err(Error::InconsistentBinding(span)),
+ }
+ }
+}
+
+pub struct Parser {
+ rules: Vec<(Rule, usize)>,
+}
+
+impl Parser {
+ pub const fn new() -> Self {
+ Parser { rules: Vec::new() }
+ }
+
+ fn reset(&mut self) {
+ self.rules.clear();
+ }
+
+ fn push_rule_span(&mut self, rule: Rule, lexer: &mut Lexer<'_>) {
+ self.rules.push((rule, lexer.start_byte_offset()));
+ }
+
+ fn pop_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
+ let (_, initial) = self.rules.pop().unwrap();
+ lexer.span_from(initial)
+ }
+
+ fn peek_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
+ let &(_, initial) = self.rules.last().unwrap();
+ lexer.span_from(initial)
+ }
+
+ fn switch_value<'a>(lexer: &mut Lexer<'a>) -> Result<(ast::SwitchValue, Span), Error<'a>> {
+ let token_span = lexer.next();
+ match token_span.0 {
+ Token::Word("default") => Ok((ast::SwitchValue::Default, token_span.1)),
+ Token::Number(Ok(Number::U32(num))) => Ok((ast::SwitchValue::U32(num), token_span.1)),
+ Token::Number(Ok(Number::I32(num))) => Ok((ast::SwitchValue::I32(num), token_span.1)),
+ Token::Number(Err(e)) => Err(Error::BadNumber(token_span.1, e)),
+ _ => Err(Error::Unexpected(token_span.1, ExpectedToken::Integer)),
+ }
+ }
+
+ /// Parse a non-negative signed integer literal.
+ /// This is for attributes like `size`, `location` and others.
+ fn non_negative_i32_literal<'a>(lexer: &mut Lexer<'a>) -> Result<u32, Error<'a>> {
+ match lexer.next() {
+ (Token::Number(Ok(Number::I32(num))), span) => {
+ u32::try_from(num).map_err(|_| Error::NegativeInt(span))
+ }
+ (Token::Number(Err(e)), span) => Err(Error::BadNumber(span, e)),
+ other => Err(Error::Unexpected(other.1, ExpectedToken::Number)),
+ }
+ }
+
+ /// Parse a non-negative integer literal that may be either signed or unsigned.
+ /// This is for the `workgroup_size` attribute and array lengths.
+ /// Note: these values should be no larger than [`i32::MAX`], but this is not checked here.
+ fn generic_non_negative_int_literal<'a>(lexer: &mut Lexer<'a>) -> Result<u32, Error<'a>> {
+ match lexer.next() {
+ (Token::Number(Ok(Number::I32(num))), span) => {
+ u32::try_from(num).map_err(|_| Error::NegativeInt(span))
+ }
+ (Token::Number(Ok(Number::U32(num))), _) => Ok(num),
+ (Token::Number(Err(e)), span) => Err(Error::BadNumber(span, e)),
+ other => Err(Error::Unexpected(other.1, ExpectedToken::Number)),
+ }
+ }
+
+ /// Decide if we're looking at a construction expression, and return its
+ /// type if so.
+ ///
+ /// If the identifier `word` is a [type-defining keyword], then return a
+ /// [`ConstructorType`] value describing the type to build. Return an error
+ /// if the type is not constructible (like `sampler`).
+ ///
+ /// If `word` isn't a type name, then return `None`.
+ ///
+ /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords
+ /// [`ConstructorType`]: ast::ConstructorType
+ fn constructor_type<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ word: &'a str,
+ span: Span,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::ConstructorType<'a>>, Error<'a>> {
+ if let Some((kind, width)) = conv::get_scalar_type(word) {
+ return Ok(Some(ast::ConstructorType::Scalar { kind, width }));
+ }
+
+ let partial = match word {
+ "vec2" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Bi,
+ },
+ "vec2i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ }))
+ }
+ "vec2u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }))
+ }
+ "vec2f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ }))
+ }
+ "vec3" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Tri,
+ },
+ "vec3i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ }))
+ }
+ "vec3u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }))
+ }
+ "vec3f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ }))
+ }
+ "vec4" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Quad,
+ },
+ "vec4i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ }))
+ }
+ "vec4u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }))
+ }
+ "vec4f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ }))
+ }
+ "mat2x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat2x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat2x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat2x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat2x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat2x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "mat3x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat3x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat3x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat3x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat3x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat3x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "mat4x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat4x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat4x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat4x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat4x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat4x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "array" => ast::ConstructorType::PartialArray,
+ "atomic"
+ | "binding_array"
+ | "sampler"
+ | "sampler_comparison"
+ | "texture_1d"
+ | "texture_1d_array"
+ | "texture_2d"
+ | "texture_2d_array"
+ | "texture_3d"
+ | "texture_cube"
+ | "texture_cube_array"
+ | "texture_multisampled_2d"
+ | "texture_multisampled_2d_array"
+ | "texture_depth_2d"
+ | "texture_depth_2d_array"
+ | "texture_depth_cube"
+ | "texture_depth_cube_array"
+ | "texture_depth_multisampled_2d"
+ | "texture_storage_1d"
+ | "texture_storage_1d_array"
+ | "texture_storage_2d"
+ | "texture_storage_2d_array"
+ | "texture_storage_3d" => return Err(Error::TypeNotConstructible(span)),
+ _ => return Ok(None),
+ };
+
+ // parse component type if present
+ match (lexer.peek().0, partial) {
+ (Token::Paren('<'), ast::ConstructorType::PartialVector { size }) => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ Ok(Some(ast::ConstructorType::Vector { size, kind, width }))
+ }
+ (Token::Paren('<'), ast::ConstructorType::PartialMatrix { columns, rows }) => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ match kind {
+ crate::ScalarKind::Float => Ok(Some(ast::ConstructorType::Matrix {
+ columns,
+ rows,
+ width,
+ })),
+ _ => Err(Error::BadMatrixScalarKind(span, kind, width)),
+ }
+ }
+ (Token::Paren('<'), ast::ConstructorType::PartialArray) => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx.reborrow())?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ ast::ArraySize::Constant(expr)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ Ok(Some(ast::ConstructorType::Array { base, size }))
+ }
+ (_, partial) => Ok(Some(partial)),
+ }
+ }
+
+ /// Expects `name` to be consumed (not in lexer).
+ fn arguments<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Vec<Handle<ast::Expression<'a>>>, Error<'a>> {
+ lexer.open_arguments()?;
+ let mut arguments = Vec::new();
+ loop {
+ if !arguments.is_empty() {
+ if !lexer.next_argument()? {
+ break;
+ }
+ } else if lexer.skip(Token::Paren(')')) {
+ break;
+ }
+ let arg = self.general_expression(lexer, ctx.reborrow())?;
+ arguments.push(arg);
+ }
+
+ Ok(arguments)
+ }
+
+ /// Expects [`Rule::PrimaryExpr`] or [`Rule::SingularExpr`] on top; does not pop it.
+ /// Expects `name` to be consumed (not in lexer).
+ fn function_call<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ name: &'a str,
+ name_span: Span,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ assert!(self.rules.last().is_some());
+
+ let expr = match name {
+ // bitcast looks like a function call, but it's an operator and must be handled differently.
+ "bitcast" => {
+ lexer.expect_generic_paren('<')?;
+ let start = lexer.start_byte_offset();
+ let to = self.type_decl(lexer, ctx.reborrow())?;
+ let span = lexer.span_from(start);
+ lexer.expect_generic_paren('>')?;
+
+ lexer.open_arguments()?;
+ let expr = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.close_arguments()?;
+
+ ast::Expression::Bitcast {
+ expr,
+ to,
+ ty_span: span,
+ }
+ }
+ // everything else must be handled later, since they can be hidden by user-defined functions.
+ _ => {
+ let arguments = self.arguments(lexer, ctx.reborrow())?;
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: name_span,
+ });
+ ast::Expression::Call {
+ function: ast::Ident {
+ name,
+ span: name_span,
+ },
+ arguments,
+ }
+ }
+ };
+
+ let span = self.peek_rule_span(lexer);
+ let expr = ctx.expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ fn ident_expr<'a>(
+ &mut self,
+ name: &'a str,
+ name_span: Span,
+ ctx: ExpressionContext<'a, '_, '_>,
+ ) -> ast::IdentExpr<'a> {
+ match ctx.local_table.lookup(name) {
+ Some(&local) => ast::IdentExpr::Local(local),
+ None => {
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: name_span,
+ });
+ ast::IdentExpr::Unresolved(name)
+ }
+ }
+ }
+
+ fn primary_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::PrimaryExpr, lexer);
+
+ let expr = match lexer.peek() {
+ (Token::Paren('('), _) => {
+ let _ = lexer.next();
+ let expr = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ self.pop_rule_span(lexer);
+ return Ok(expr);
+ }
+ (Token::Word("true"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Bool(true))
+ }
+ (Token::Word("false"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Bool(false))
+ }
+ (Token::Number(res), span) => {
+ let _ = lexer.next();
+ let num = res.map_err(|err| Error::BadNumber(span, err))?;
+ ast::Expression::Literal(ast::Literal::Number(num))
+ }
+ (Token::Word("RAY_FLAG_NONE"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+ }
+ (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(4)))
+ }
+ (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+ }
+ (Token::Word(word), span) => {
+ let start = lexer.start_byte_offset();
+ let _ = lexer.next();
+
+ if let Some(ty) = self.constructor_type(lexer, word, span, ctx.reborrow())? {
+ let ty_span = lexer.span_from(start);
+ let components = self.arguments(lexer, ctx.reborrow())?;
+ ast::Expression::Construct {
+ ty,
+ ty_span,
+ components,
+ }
+ } else if let Token::Paren('(') = lexer.peek().0 {
+ self.pop_rule_span(lexer);
+ return self.function_call(lexer, word, span, ctx);
+ } else if word == "bitcast" {
+ self.pop_rule_span(lexer);
+ return self.function_call(lexer, word, span, ctx);
+ } else {
+ let ident = self.ident_expr(word, span, ctx.reborrow());
+ ast::Expression::Ident(ident)
+ }
+ }
+ other => return Err(Error::Unexpected(other.1, ExpectedToken::PrimaryExpression)),
+ };
+
+ let span = self.pop_rule_span(lexer);
+ let expr = ctx.expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ fn postfix<'a>(
+ &mut self,
+ span_start: usize,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ expr: Handle<ast::Expression<'a>>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let mut expr = expr;
+
+ loop {
+ let expression = match lexer.peek().0 {
+ Token::Separator('.') => {
+ let _ = lexer.next();
+ let field = lexer.next_ident()?;
+
+ ast::Expression::Member { base: expr, field }
+ }
+ Token::Paren('[') => {
+ let _ = lexer.next();
+ let index = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(']'))?;
+
+ ast::Expression::Index { base: expr, index }
+ }
+ _ => break,
+ };
+
+ let span = lexer.span_from(span_start);
+ expr = ctx.expressions.append(expression, span);
+ }
+
+ Ok(expr)
+ }
+
+ /// Parse a `unary_expression`.
+ fn unary_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::UnaryExpr, lexer);
+ //TODO: refactor this to avoid backing up
+ let expr = match lexer.peek().0 {
+ Token::Operation('-') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('!' | '~') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::Not,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('*') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ let expr = ast::Expression::Deref(expr);
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('&') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ let expr = ast::Expression::AddrOf(expr);
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ _ => self.singular_expression(lexer, ctx.reborrow())?,
+ };
+
+ self.pop_rule_span(lexer);
+ Ok(expr)
+ }
+
+ /// Parse a `singular_expression`.
+ fn singular_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let start = lexer.start_byte_offset();
+ self.push_rule_span(Rule::SingularExpr, lexer);
+ let primary_expr = self.primary_expression(lexer, ctx.reborrow())?;
+ let singular_expr = self.postfix(start, lexer, ctx.reborrow(), primary_expr)?;
+ self.pop_rule_span(lexer);
+
+ Ok(singular_expr)
+ }
+
+ fn equality_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ // equality_expression
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal),
+ Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual),
+ _ => None,
+ },
+ // relational_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Paren('<') => Some(crate::BinaryOperator::Less),
+ Token::Paren('>') => Some(crate::BinaryOperator::Greater),
+ Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual),
+ Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual),
+ _ => None,
+ },
+ // shift_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::ShiftOperation('<') => {
+ Some(crate::BinaryOperator::ShiftLeft)
+ }
+ Token::ShiftOperation('>') => {
+ Some(crate::BinaryOperator::ShiftRight)
+ }
+ _ => None,
+ },
+ // additive_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('+') => Some(crate::BinaryOperator::Add),
+ Token::Operation('-') => {
+ Some(crate::BinaryOperator::Subtract)
+ }
+ _ => None,
+ },
+ // multiplicative_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('*') => {
+ Some(crate::BinaryOperator::Multiply)
+ }
+ Token::Operation('/') => {
+ Some(crate::BinaryOperator::Divide)
+ }
+ Token::Operation('%') => {
+ Some(crate::BinaryOperator::Modulo)
+ }
+ _ => None,
+ },
+ |lexer, context| self.unary_expression(lexer, context),
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ }
+
+ fn general_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.general_expression_with_span(lexer, ctx.reborrow())
+ .map(|(expr, _)| expr)
+ }
+
+ fn general_expression_with_span<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: ExpressionContext<'a, '_, '_>,
+ ) -> Result<(Handle<ast::Expression<'a>>, Span), Error<'a>> {
+ self.push_rule_span(Rule::GeneralExpr, lexer);
+ // logical_or_expression
+ let handle = context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr),
+ _ => None,
+ },
+ // logical_and_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd),
+ _ => None,
+ },
+ // inclusive_or_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr),
+ _ => None,
+ },
+ // exclusive_or_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('^') => {
+ Some(crate::BinaryOperator::ExclusiveOr)
+ }
+ _ => None,
+ },
+ // and_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('&') => {
+ Some(crate::BinaryOperator::And)
+ }
+ _ => None,
+ },
+ |lexer, context| {
+ self.equality_expression(lexer, context)
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )?;
+ Ok((handle, self.pop_rule_span(lexer)))
+ }
+
+ fn variable_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::GlobalVariable<'a>, Error<'a>> {
+ self.push_rule_span(Rule::VariableDecl, lexer);
+ let mut space = crate::AddressSpace::Handle;
+
+ if lexer.skip(Token::Paren('<')) {
+ let (class_str, span) = lexer.next_ident_with_span()?;
+ space = match class_str {
+ "storage" => {
+ let access = if lexer.skip(Token::Separator(',')) {
+ lexer.next_storage_access()?
+ } else {
+ // defaulting to `read`
+ crate::StorageAccess::LOAD
+ };
+ crate::AddressSpace::Storage { access }
+ }
+ _ => conv::map_address_space(class_str, span)?,
+ };
+ lexer.expect(Token::Paren('>'))?;
+ }
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ let handle = self.general_expression(lexer, ctx.reborrow())?;
+ Some(handle)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+
+ Ok(ast::GlobalVariable {
+ name,
+ space,
+ binding: None,
+ ty,
+ init,
+ })
+ }
+
+ fn struct_body<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Vec<ast::StructMember<'a>>, Error<'a>> {
+ let mut members = Vec::new();
+
+ lexer.expect(Token::Paren('{'))?;
+ let mut ready = true;
+ while !lexer.skip(Token::Paren('}')) {
+ if !ready {
+ return Err(Error::Unexpected(
+ lexer.next().1,
+ ExpectedToken::Token(Token::Separator(',')),
+ ));
+ }
+ let (mut size, mut align) = (None, None);
+ self.push_rule_span(Rule::Attribute, lexer);
+ let mut bind_parser = BindingParser::default();
+ while lexer.skip(Token::Attribute) {
+ match lexer.next_ident_with_span()? {
+ ("size", _) => {
+ lexer.expect(Token::Paren('('))?;
+ let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
+ lexer.expect(Token::Paren(')'))?;
+ size = Some((value, span));
+ }
+ ("align", _) => {
+ lexer.expect(Token::Paren('('))?;
+ let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
+ lexer.expect(Token::Paren(')'))?;
+ align = Some((value, span));
+ }
+ (word, word_span) => bind_parser.parse(lexer, word, word_span)?,
+ }
+ }
+
+ let bind_span = self.pop_rule_span(lexer);
+ let binding = bind_parser.finish(bind_span)?;
+
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ ready = lexer.skip(Token::Separator(','));
+
+ members.push(ast::StructMember {
+ name,
+ ty,
+ binding,
+ size,
+ align,
+ });
+ }
+
+ Ok(members)
+ }
+
+ fn matrix_scalar_type<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ ) -> Result<ast::Type<'a>, Error<'a>> {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ match kind {
+ crate::ScalarKind::Float => Ok(ast::Type::Matrix {
+ columns,
+ rows,
+ width,
+ }),
+ _ => Err(Error::BadMatrixScalarKind(span, kind, width)),
+ }
+ }
+
+ fn type_decl_impl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ word: &'a str,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::Type<'a>>, Error<'a>> {
+ if let Some((kind, width)) = conv::get_scalar_type(word) {
+ return Ok(Some(ast::Type::Scalar { kind, width }));
+ }
+
+ Ok(Some(match word {
+ "vec2" => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ kind,
+ width,
+ }
+ }
+ "vec2i" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ "vec2u" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ "vec2f" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ "vec3" => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ kind,
+ width,
+ }
+ }
+ "vec3i" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ "vec3u" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ "vec3f" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ "vec4" => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ kind,
+ width,
+ }
+ }
+ "vec4i" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ "vec4u" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ "vec4f" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ "mat2x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Bi)?
+ }
+ "mat2x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat2x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Tri)?
+ }
+ "mat2x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat2x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Quad)?
+ }
+ "mat2x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "mat3x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Bi)?
+ }
+ "mat3x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat3x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Tri)?
+ }
+ "mat3x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat3x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Quad)?
+ }
+ "mat3x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "mat4x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Bi)?
+ }
+ "mat4x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat4x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Tri)?
+ }
+ "mat4x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat4x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Quad)?
+ }
+ "mat4x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "atomic" => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ ast::Type::Atomic { kind, width }
+ }
+ "ptr" => {
+ lexer.expect_generic_paren('<')?;
+ let (ident, span) = lexer.next_ident_with_span()?;
+ let mut space = conv::map_address_space(ident, span)?;
+ lexer.expect(Token::Separator(','))?;
+ let base = self.type_decl(lexer, ctx)?;
+ if let crate::AddressSpace::Storage { ref mut access } = space {
+ *access = if lexer.skip(Token::Separator(',')) {
+ lexer.next_storage_access()?
+ } else {
+ crate::StorageAccess::LOAD
+ };
+ }
+ lexer.expect_generic_paren('>')?;
+ ast::Type::Pointer { base, space }
+ }
+ "array" => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx.reborrow())?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let size = self.unary_expression(lexer, ctx.reborrow())?;
+ ast::ArraySize::Constant(size)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ ast::Type::Array { base, size }
+ }
+ "binding_array" => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx.reborrow())?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let size = self.unary_expression(lexer, ctx.reborrow())?;
+ ast::ArraySize::Constant(size)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ ast::Type::BindingArray { base, size }
+ }
+ "sampler" => ast::Type::Sampler { comparison: false },
+ "sampler_comparison" => ast::Type::Sampler { comparison: true },
+ "texture_1d" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_1d_array" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_2d" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_2d_array" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_3d" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_cube" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_cube_array" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_multisampled_2d" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: true },
+ }
+ }
+ "texture_multisampled_2d_array" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: true },
+ }
+ }
+ "texture_depth_2d" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_2d_array" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_cube" => ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_cube_array" => ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_multisampled_2d" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: true },
+ },
+ "texture_storage_1d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_1d_array" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_2d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_2d_array" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_3d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "acceleration_structure" => ast::Type::AccelerationStructure,
+ "ray_query" => ast::Type::RayQuery,
+ "RayDesc" => ast::Type::RayDesc,
+ "RayIntersection" => ast::Type::RayIntersection,
+ _ => return Ok(None),
+ }))
+ }
+
+ const fn check_texture_sample_type(
+ kind: crate::ScalarKind,
+ width: u8,
+ span: Span,
+ ) -> Result<(), Error<'static>> {
+ use crate::ScalarKind::*;
+ // Validate according to https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ match (kind, width) {
+ (Float | Sint | Uint, 4) => Ok(()),
+ _ => Err(Error::BadTextureSampleType { span, kind, width }),
+ }
+ }
+
+ /// Parse type declaration of a given name.
+ fn type_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Type<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::TypeDecl, lexer);
+
+ let (name, span) = lexer.next_ident_with_span()?;
+
+ let ty = match self.type_decl_impl(lexer, name, ctx.reborrow())? {
+ Some(ty) => ty,
+ None => {
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: span,
+ });
+ ast::Type::User(ast::Ident { name, span })
+ }
+ };
+
+ self.pop_rule_span(lexer);
+
+ let handle = ctx.types.append(ty, Span::UNDEFINED);
+ Ok(handle)
+ }
+
+ fn assignment_op_and_rhs<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ target: Handle<ast::Expression<'a>>,
+ span_start: usize,
+ ) -> Result<(), Error<'a>> {
+ use crate::BinaryOperator as Bo;
+
+ let op = lexer.next();
+ let (op, value) = match op {
+ (Token::Operation('='), _) => {
+ let value = self.general_expression(lexer, ctx.reborrow())?;
+ (None, value)
+ }
+ (Token::AssignmentOperation(c), _) => {
+ let op = match c {
+ '<' => Bo::ShiftLeft,
+ '>' => Bo::ShiftRight,
+ '+' => Bo::Add,
+ '-' => Bo::Subtract,
+ '*' => Bo::Multiply,
+ '/' => Bo::Divide,
+ '%' => Bo::Modulo,
+ '&' => Bo::And,
+ '|' => Bo::InclusiveOr,
+ '^' => Bo::ExclusiveOr,
+ // Note: `consume_token` shouldn't produce any other assignment ops
+ _ => unreachable!(),
+ };
+
+ let value = self.general_expression(lexer, ctx.reborrow())?;
+ (Some(op), value)
+ }
+ token @ (Token::IncrementOperation | Token::DecrementOperation, _) => {
+ let op = match token.0 {
+ Token::IncrementOperation => ast::StatementKind::Increment,
+ Token::DecrementOperation => ast::StatementKind::Decrement,
+ _ => unreachable!(),
+ };
+
+ let span = lexer.span_from(span_start);
+ block.stmts.push(ast::Statement {
+ kind: op(target),
+ span,
+ });
+ return Ok(());
+ }
+ _ => return Err(Error::Unexpected(op.1, ExpectedToken::Assignment)),
+ };
+
+ let span = lexer.span_from(span_start);
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Assign { target, op, value },
+ span,
+ });
+ Ok(())
+ }
+
+ /// Parse an assignment statement (will also parse increment and decrement statements)
+ fn assignment_statement<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ let span_start = lexer.start_byte_offset();
+ let target = self.general_expression(lexer, ctx.reborrow())?;
+ self.assignment_op_and_rhs(lexer, ctx, block, target, span_start)
+ }
+
+ /// Parse a function call statement.
+ /// Expects `ident` to be consumed (not in the lexer).
+ fn function_statement<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ident: &'a str,
+ ident_span: Span,
+ span_start: usize,
+ mut context: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ self.push_rule_span(Rule::SingularExpr, lexer);
+
+ context.unresolved.insert(ast::Dependency {
+ ident,
+ usage: ident_span,
+ });
+ let arguments = self.arguments(lexer, context.reborrow())?;
+ let span = lexer.span_from(span_start);
+
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Call {
+ function: ast::Ident {
+ name: ident,
+ span: ident_span,
+ },
+ arguments,
+ },
+ span,
+ });
+
+ self.pop_rule_span(lexer);
+
+ Ok(())
+ }
+
+ fn function_call_or_assignment_statement<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ let span_start = lexer.start_byte_offset();
+ match lexer.peek() {
+ (Token::Word(name), span) => {
+ // A little hack for 2 token lookahead.
+ let cloned = lexer.clone();
+ let _ = lexer.next();
+ match lexer.peek() {
+ (Token::Paren('('), _) => self.function_statement(
+ lexer,
+ name,
+ span,
+ span_start,
+ context.reborrow(),
+ block,
+ ),
+ _ => {
+ *lexer = cloned;
+ self.assignment_statement(lexer, context.reborrow(), block)
+ }
+ }
+ }
+ _ => self.assignment_statement(lexer, context.reborrow(), block),
+ }
+ }
+
+ fn statement<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ self.push_rule_span(Rule::Statement, lexer);
+ match lexer.peek() {
+ (Token::Separator(';'), _) => {
+ let _ = lexer.next();
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ (Token::Paren('{'), _) => {
+ let (inner, span) = self.block(lexer, ctx.reborrow())?;
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(inner),
+ span,
+ });
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ (Token::Word(word), _) => {
+ let kind = match word {
+ "_" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Operation('='))?;
+ let expr = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(';'))?;
+
+ ast::StatementKind::Ignore(expr)
+ }
+ "let" => {
+ let _ = lexer.next();
+ let name = lexer.next_ident()?;
+
+ let given_ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ Some(ty)
+ } else {
+ None
+ };
+ lexer.expect(Token::Operation('='))?;
+ let expr_id = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(';'))?;
+
+ let handle = ctx.locals.append(ast::Local, name.span);
+ if let Some(old) = ctx.local_table.add(name.name, handle) {
+ return Err(Error::Redefinition {
+ previous: ctx.locals.get_span(old),
+ current: name.span,
+ });
+ }
+
+ ast::StatementKind::LocalDecl(ast::LocalDecl::Let(ast::Let {
+ name,
+ ty: given_ty,
+ init: expr_id,
+ handle,
+ }))
+ }
+ "var" => {
+ let _ = lexer.next();
+
+ let name = lexer.next_ident()?;
+ let ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ Some(ty)
+ } else {
+ None
+ };
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ let init = self.general_expression(lexer, ctx.reborrow())?;
+ Some(init)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Separator(';'))?;
+
+ let handle = ctx.locals.append(ast::Local, name.span);
+ if let Some(old) = ctx.local_table.add(name.name, handle) {
+ return Err(Error::Redefinition {
+ previous: ctx.locals.get_span(old),
+ current: name.span,
+ });
+ }
+
+ ast::StatementKind::LocalDecl(ast::LocalDecl::Var(ast::LocalVariable {
+ name,
+ ty,
+ init,
+ handle,
+ }))
+ }
+ "return" => {
+ let _ = lexer.next();
+ let value = if lexer.peek().0 != Token::Separator(';') {
+ let handle = self.general_expression(lexer, ctx.reborrow())?;
+ Some(handle)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Return { value }
+ }
+ "if" => {
+ let _ = lexer.next();
+ let condition = self.general_expression(lexer, ctx.reborrow())?;
+
+ let accept = self.block(lexer, ctx.reborrow())?.0;
+
+ let mut elsif_stack = Vec::new();
+ let mut elseif_span_start = lexer.start_byte_offset();
+ let mut reject = loop {
+ if !lexer.skip(Token::Word("else")) {
+ break ast::Block::default();
+ }
+
+ if !lexer.skip(Token::Word("if")) {
+ // ... else { ... }
+ break self.block(lexer, ctx.reborrow())?.0;
+ }
+
+ // ... else if (...) { ... }
+ let other_condition = self.general_expression(lexer, ctx.reborrow())?;
+ let other_block = self.block(lexer, ctx.reborrow())?;
+ elsif_stack.push((elseif_span_start, other_condition, other_block));
+ elseif_span_start = lexer.start_byte_offset();
+ };
+
+ // reverse-fold the else-if blocks
+ //Note: we may consider uplifting this to the IR
+ for (other_span_start, other_cond, other_block) in
+ elsif_stack.into_iter().rev()
+ {
+ let sub_stmt = ast::StatementKind::If {
+ condition: other_cond,
+ accept: other_block.0,
+ reject,
+ };
+ reject = ast::Block::default();
+ let span = lexer.span_from(other_span_start);
+ reject.stmts.push(ast::Statement {
+ kind: sub_stmt,
+ span,
+ })
+ }
+
+ ast::StatementKind::If {
+ condition,
+ accept,
+ reject,
+ }
+ }
+ "switch" => {
+ let _ = lexer.next();
+ let selector = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren('{'))?;
+ let mut cases = Vec::new();
+
+ loop {
+ // cases + default
+ match lexer.next() {
+ (Token::Word("case"), _) => {
+ // parse a list of values
+ let (value, value_span) = loop {
+ let (value, value_span) = Self::switch_value(lexer)?;
+ if lexer.skip(Token::Separator(',')) {
+ if lexer.skip(Token::Separator(':')) {
+ break (value, value_span);
+ }
+ } else {
+ lexer.skip(Token::Separator(':'));
+ break (value, value_span);
+ }
+ cases.push(ast::SwitchCase {
+ value,
+ value_span,
+ body: ast::Block::default(),
+ fall_through: true,
+ });
+ };
+
+ let body = self.block(lexer, ctx.reborrow())?.0;
+
+ cases.push(ast::SwitchCase {
+ value,
+ value_span,
+ body,
+ fall_through: false,
+ });
+ }
+ (Token::Word("default"), value_span) => {
+ lexer.skip(Token::Separator(':'));
+ let body = self.block(lexer, ctx.reborrow())?.0;
+ cases.push(ast::SwitchCase {
+ value: ast::SwitchValue::Default,
+ value_span,
+ body,
+ fall_through: false,
+ });
+ }
+ (Token::Paren('}'), _) => break,
+ (_, span) => {
+ return Err(Error::Unexpected(span, ExpectedToken::SwitchItem))
+ }
+ }
+ }
+
+ ast::StatementKind::Switch { selector, cases }
+ }
+ "loop" => self.r#loop(lexer, ctx.reborrow())?,
+ "while" => {
+ let _ = lexer.next();
+ let mut body = ast::Block::default();
+
+ let (condition, span) = lexer.capture_span(|lexer| {
+ let condition = self.general_expression(lexer, ctx.reborrow())?;
+ Ok(condition)
+ })?;
+ let mut reject = ast::Block::default();
+ reject.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Break,
+ span,
+ });
+
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::If {
+ condition,
+ accept: ast::Block::default(),
+ reject,
+ },
+ span,
+ });
+
+ let (block, span) = self.block(lexer, ctx.reborrow())?;
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(block),
+ span,
+ });
+
+ ast::StatementKind::Loop {
+ body,
+ continuing: ast::Block::default(),
+ break_if: None,
+ }
+ }
+ "for" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Paren('('))?;
+
+ ctx.local_table.push_scope();
+
+ if !lexer.skip(Token::Separator(';')) {
+ let num_statements = block.stmts.len();
+ let (_, span) = lexer.capture_span(|lexer| {
+ self.statement(lexer, ctx.reborrow(), block)
+ })?;
+
+ if block.stmts.len() != num_statements {
+ match block.stmts.last().unwrap().kind {
+ ast::StatementKind::Call { .. }
+ | ast::StatementKind::Assign { .. }
+ | ast::StatementKind::LocalDecl(_) => {}
+ _ => return Err(Error::InvalidForInitializer(span)),
+ }
+ }
+ };
+
+ let mut body = ast::Block::default();
+ if !lexer.skip(Token::Separator(';')) {
+ let (condition, span) = lexer.capture_span(|lexer| {
+ let condition = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(';'))?;
+ Ok(condition)
+ })?;
+ let mut reject = ast::Block::default();
+ reject.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Break,
+ span,
+ });
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::If {
+ condition,
+ accept: ast::Block::default(),
+ reject,
+ },
+ span,
+ });
+ };
+
+ let mut continuing = ast::Block::default();
+ if !lexer.skip(Token::Paren(')')) {
+ self.function_call_or_assignment_statement(
+ lexer,
+ ctx.reborrow(),
+ &mut continuing,
+ )?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+
+ let (block, span) = self.block(lexer, ctx.reborrow())?;
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(block),
+ span,
+ });
+
+ ctx.local_table.pop_scope();
+
+ ast::StatementKind::Loop {
+ body,
+ continuing,
+ break_if: None,
+ }
+ }
+ "break" => {
+ let (_, span) = lexer.next();
+ // Check if the next token is an `if`, this indicates
+ // that the user tried to type out a `break if` which
+ // is illegal in this position.
+ let (peeked_token, peeked_span) = lexer.peek();
+ if let Token::Word("if") = peeked_token {
+ let span = span.until(&peeked_span);
+ return Err(Error::InvalidBreakIf(span));
+ }
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Break
+ }
+ "continue" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Continue
+ }
+ "discard" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Kill
+ }
+ // assignment or a function call
+ _ => {
+ self.function_call_or_assignment_statement(lexer, ctx.reborrow(), block)?;
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ };
+
+ let span = self.pop_rule_span(lexer);
+ block.stmts.push(ast::Statement { kind, span });
+ }
+ _ => {
+ self.assignment_statement(lexer, ctx.reborrow(), block)?;
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+ }
+ }
+ Ok(())
+ }
+
+ fn r#loop<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::StatementKind<'a>, Error<'a>> {
+ let _ = lexer.next();
+ let mut body = ast::Block::default();
+ let mut continuing = ast::Block::default();
+ let mut break_if = None;
+
+ lexer.expect(Token::Paren('{'))?;
+
+ ctx.local_table.push_scope();
+
+ loop {
+ if lexer.skip(Token::Word("continuing")) {
+ // Branch for the `continuing` block, this must be
+ // the last thing in the loop body
+
+ // Expect a opening brace to start the continuing block
+ lexer.expect(Token::Paren('{'))?;
+ loop {
+ if lexer.skip(Token::Word("break")) {
+ // Branch for the `break if` statement, this statement
+ // has the form `break if <expr>;` and must be the last
+ // statement in a continuing block
+
+ // The break must be followed by an `if` to form
+ // the break if
+ lexer.expect(Token::Word("if"))?;
+
+ let condition = self.general_expression(lexer, ctx.reborrow())?;
+ // Set the condition of the break if to the newly parsed
+ // expression
+ break_if = Some(condition);
+
+ // Expect a semicolon to close the statement
+ lexer.expect(Token::Separator(';'))?;
+ // Expect a closing brace to close the continuing block,
+ // since the break if must be the last statement
+ lexer.expect(Token::Paren('}'))?;
+ // Stop parsing the continuing block
+ break;
+ } else if lexer.skip(Token::Paren('}')) {
+ // If we encounter a closing brace it means we have reached
+ // the end of the continuing block and should stop processing
+ break;
+ } else {
+ // Otherwise try to parse a statement
+ self.statement(lexer, ctx.reborrow(), &mut continuing)?;
+ }
+ }
+ // Since the continuing block must be the last part of the loop body,
+ // we expect to see a closing brace to end the loop body
+ lexer.expect(Token::Paren('}'))?;
+ break;
+ }
+ if lexer.skip(Token::Paren('}')) {
+ // If we encounter a closing brace it means we have reached
+ // the end of the loop body and should stop processing
+ break;
+ }
+ // Otherwise try to parse a statement
+ self.statement(lexer, ctx.reborrow(), &mut body)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ Ok(ast::StatementKind::Loop {
+ body,
+ continuing,
+ break_if,
+ })
+ }
+
+ /// compound_statement
+ fn block<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<(ast::Block<'a>, Span), Error<'a>> {
+ self.push_rule_span(Rule::Block, lexer);
+
+ ctx.local_table.push_scope();
+
+ lexer.expect(Token::Paren('{'))?;
+ let mut statements = ast::Block::default();
+ while !lexer.skip(Token::Paren('}')) {
+ self.statement(lexer, ctx.reborrow(), &mut statements)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ let span = self.pop_rule_span(lexer);
+ Ok((statements, span))
+ }
+
+ fn varying_binding<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ) -> Result<Option<crate::Binding>, Error<'a>> {
+ let mut bind_parser = BindingParser::default();
+ self.push_rule_span(Rule::Attribute, lexer);
+
+ while lexer.skip(Token::Attribute) {
+ let (word, span) = lexer.next_ident_with_span()?;
+ bind_parser.parse(lexer, word, span)?;
+ }
+
+ let span = self.pop_rule_span(lexer);
+ bind_parser.finish(span)
+ }
+
+ fn function_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ out: &mut ast::TranslationUnit<'a>,
+ dependencies: &mut FastHashSet<ast::Dependency<'a>>,
+ ) -> Result<ast::Function<'a>, Error<'a>> {
+ self.push_rule_span(Rule::FunctionDecl, lexer);
+ // read function name
+ let fun_name = lexer.next_ident()?;
+
+ let mut locals = Arena::new();
+
+ let mut ctx = ExpressionContext {
+ expressions: &mut out.expressions,
+ local_table: &mut SymbolTable::default(),
+ locals: &mut locals,
+ types: &mut out.types,
+ unresolved: dependencies,
+ };
+
+ // read parameter list
+ let mut arguments = Vec::new();
+ lexer.expect(Token::Paren('('))?;
+ let mut ready = true;
+ while !lexer.skip(Token::Paren(')')) {
+ if !ready {
+ return Err(Error::Unexpected(
+ lexer.next().1,
+ ExpectedToken::Token(Token::Separator(',')),
+ ));
+ }
+ let binding = self.varying_binding(lexer)?;
+
+ let param_name = lexer.next_ident()?;
+
+ lexer.expect(Token::Separator(':'))?;
+ let param_type = self.type_decl(lexer, ctx.reborrow())?;
+
+ let handle = ctx.locals.append(ast::Local, param_name.span);
+ ctx.local_table.add(param_name.name, handle);
+ arguments.push(ast::FunctionArgument {
+ name: param_name,
+ ty: param_type,
+ binding,
+ handle,
+ });
+ ready = lexer.skip(Token::Separator(','));
+ }
+ // read return type
+ let result = if lexer.skip(Token::Arrow) && !lexer.skip(Token::Word("void")) {
+ let binding = self.varying_binding(lexer)?;
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ Some(ast::FunctionResult { ty, binding })
+ } else {
+ None
+ };
+
+ // read body
+ let body = self.block(lexer, ctx)?.0;
+
+ let fun = ast::Function {
+ entry_point: None,
+ name: fun_name,
+ arguments,
+ result,
+ body,
+ locals,
+ };
+
+ // done
+ self.pop_rule_span(lexer);
+
+ Ok(fun)
+ }
+
+ fn global_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ out: &mut ast::TranslationUnit<'a>,
+ ) -> Result<(), Error<'a>> {
+ // read attributes
+ let mut binding = None;
+ let mut stage = None;
+ let mut workgroup_size = [0u32; 3];
+ let mut early_depth_test = None;
+ let (mut bind_index, mut bind_group) = (None, None);
+
+ self.push_rule_span(Rule::Attribute, lexer);
+ while lexer.skip(Token::Attribute) {
+ match lexer.next_ident_with_span()? {
+ ("binding", _) => {
+ lexer.expect(Token::Paren('('))?;
+ bind_index = Some(Self::non_negative_i32_literal(lexer)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ ("group", _) => {
+ lexer.expect(Token::Paren('('))?;
+ bind_group = Some(Self::non_negative_i32_literal(lexer)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ ("vertex", _) => {
+ stage = Some(crate::ShaderStage::Vertex);
+ }
+ ("fragment", _) => {
+ stage = Some(crate::ShaderStage::Fragment);
+ }
+ ("compute", _) => {
+ stage = Some(crate::ShaderStage::Compute);
+ }
+ ("workgroup_size", _) => {
+ lexer.expect(Token::Paren('('))?;
+ workgroup_size = [1u32; 3];
+ for (i, size) in workgroup_size.iter_mut().enumerate() {
+ *size = Self::generic_non_negative_int_literal(lexer)?;
+ match lexer.next() {
+ (Token::Paren(')'), _) => break,
+ (Token::Separator(','), _) if i != 2 => (),
+ other => {
+ return Err(Error::Unexpected(
+ other.1,
+ ExpectedToken::WorkgroupSizeSeparator,
+ ))
+ }
+ }
+ }
+ }
+ ("early_depth_test", _) => {
+ let conservative = if lexer.skip(Token::Paren('(')) {
+ let (ident, ident_span) = lexer.next_ident_with_span()?;
+ let value = conv::map_conservative_depth(ident, ident_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(value)
+ } else {
+ None
+ };
+ early_depth_test = Some(crate::EarlyDepthTest { conservative });
+ }
+ (_, word_span) => return Err(Error::UnknownAttribute(word_span)),
+ }
+ }
+
+ let attrib_span = self.pop_rule_span(lexer);
+ match (bind_group, bind_index) {
+ (Some(group), Some(index)) => {
+ binding = Some(crate::ResourceBinding {
+ group,
+ binding: index,
+ });
+ }
+ (Some(_), None) => return Err(Error::MissingAttribute("binding", attrib_span)),
+ (None, Some(_)) => return Err(Error::MissingAttribute("group", attrib_span)),
+ (None, None) => {}
+ }
+
+ let mut dependencies = FastHashSet::default();
+ let mut ctx = ExpressionContext {
+ expressions: &mut out.expressions,
+ local_table: &mut SymbolTable::default(),
+ locals: &mut Arena::new(),
+ types: &mut out.types,
+ unresolved: &mut dependencies,
+ };
+
+ // read item
+ let start = lexer.start_byte_offset();
+ let kind = match lexer.next() {
+ (Token::Separator(';'), _) => None,
+ (Token::Word("struct"), _) => {
+ let name = lexer.next_ident()?;
+
+ let members = self.struct_body(lexer, ctx)?;
+ Some(ast::GlobalDeclKind::Struct(ast::Struct { name, members }))
+ }
+ (Token::Word("alias"), _) => {
+ let name = lexer.next_ident()?;
+
+ lexer.expect(Token::Operation('='))?;
+ let ty = self.type_decl(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+ Some(ast::GlobalDeclKind::Type(ast::TypeAlias { name, ty }))
+ }
+ (Token::Word("const"), _) => {
+ let name = lexer.next_ident()?;
+
+ let ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ Some(ty)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Operation('='))?;
+ let init = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+
+ Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init }))
+ }
+ (Token::Word("var"), _) => {
+ let mut var = self.variable_decl(lexer, ctx)?;
+ var.binding = binding.take();
+ Some(ast::GlobalDeclKind::Var(var))
+ }
+ (Token::Word("fn"), _) => {
+ let function = self.function_decl(lexer, out, &mut dependencies)?;
+ Some(ast::GlobalDeclKind::Fn(ast::Function {
+ entry_point: stage.map(|stage| ast::EntryPoint {
+ stage,
+ early_depth_test,
+ workgroup_size,
+ }),
+ ..function
+ }))
+ }
+ (Token::End, _) => return Ok(()),
+ other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)),
+ };
+
+ if let Some(kind) = kind {
+ out.decls.append(
+ ast::GlobalDecl { kind, dependencies },
+ lexer.span_from(start),
+ );
+ }
+
+ if !self.rules.is_empty() {
+ log::error!("Reached the end of global decl, but rule stack is not empty");
+ log::error!("Rules: {:?}", self.rules);
+ return Err(Error::Other);
+ };
+
+ match binding {
+ None => Ok(()),
+ // we had the attribute but no var?
+ Some(_) => Err(Error::Other),
+ }
+ }
+
+ pub fn parse<'a>(&mut self, source: &'a str) -> Result<ast::TranslationUnit<'a>, Error<'a>> {
+ self.reset();
+
+ let mut lexer = Lexer::new(source);
+ let mut tu = ast::TranslationUnit::default();
+ loop {
+ match self.global_decl(&mut lexer, &mut tu) {
+ Err(error) => return Err(error),
+ Ok(()) => {
+ if lexer.peek().0 == Token::End {
+ break;
+ }
+ }
+ }
+ }
+
+ Ok(tu)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/number.rs b/third_party/rust/naga/src/front/wgsl/parse/number.rs
new file mode 100644
index 0000000000..57a2be6142
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/number.rs
@@ -0,0 +1,443 @@
+use std::borrow::Cow;
+
+use crate::front::wgsl::error::NumberError;
+use crate::front::wgsl::parse::lexer::Token;
+
+/// When using this type assume no Abstract Int/Float for now
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Number {
+ /// Abstract Int (-2^63 ≤ i < 2^63)
+ AbstractInt(i64),
+ /// Abstract Float (IEEE-754 binary64)
+ AbstractFloat(f64),
+ /// Concrete i32
+ I32(i32),
+ /// Concrete u32
+ U32(u32),
+ /// Concrete f32
+ F32(f32),
+}
+
+impl Number {
+ /// Convert abstract numbers to a plausible concrete counterpart.
+ ///
+ /// Return concrete numbers unchanged. If the conversion would be
+ /// lossy, return an error.
+ fn abstract_to_concrete(self) -> Result<Number, NumberError> {
+ match self {
+ Number::AbstractInt(num) => i32::try_from(num)
+ .map(Number::I32)
+ .map_err(|_| NumberError::NotRepresentable),
+ Number::AbstractFloat(num) => {
+ let num = num as f32;
+ if num.is_finite() {
+ Ok(Number::F32(num))
+ } else {
+ Err(NumberError::NotRepresentable)
+ }
+ }
+ num => Ok(num),
+ }
+ }
+}
+
+// TODO: when implementing Creation-Time Expressions, remove the ability to match the minus sign
+
+pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
+ let (result, rest) = parse(input);
+ (
+ Token::Number(result.and_then(Number::abstract_to_concrete)),
+ rest,
+ )
+}
+
+enum Kind {
+ Int(IntKind),
+ Float(FloatKind),
+}
+
+enum IntKind {
+ I32,
+ U32,
+}
+
+enum FloatKind {
+ F32,
+ F16,
+}
+
+// The following regexes (from the WGSL spec) will be matched:
+
+// int_literal:
+// | / 0 [iu]? /
+// | / [1-9][0-9]* [iu]? /
+// | / 0[xX][0-9a-fA-F]+ [iu]? /
+
+// decimal_float_literal:
+// | / 0 [fh] /
+// | / [1-9][0-9]* [fh] /
+// | / [0-9]* \.[0-9]+ ([eE][+-]?[0-9]+)? [fh]? /
+// | / [0-9]+ \.[0-9]* ([eE][+-]?[0-9]+)? [fh]? /
+// | / [0-9]+ [eE][+-]?[0-9]+ [fh]? /
+
+// hex_float_literal:
+// | / 0[xX][0-9a-fA-F]* \.[0-9a-fA-F]+ ([pP][+-]?[0-9]+ [fh]?)? /
+// | / 0[xX][0-9a-fA-F]+ \.[0-9a-fA-F]* ([pP][+-]?[0-9]+ [fh]?)? /
+// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? /
+
+// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing
+// -?(?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))
+
+fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
+ /// returns `true` and consumes `X` bytes from the given byte buffer
+ /// if the given `X` nr of patterns are found at the start of the buffer
+ macro_rules! consume {
+ ($bytes:ident, $($pattern:pat),*) => {
+ match $bytes {
+ &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
+ _ => false,
+ }
+ };
+ }
+
+ /// consumes one byte from the given byte buffer
+ /// if one of the given patterns are found at the start of the buffer
+ /// returning the corresponding expr for the matched pattern
+ macro_rules! consume_map {
+ ($bytes:ident, [$($pattern:pat_param => $to:expr),*]) => {
+ match $bytes {
+ $( &[$pattern, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
+ _ => None,
+ }
+ };
+ }
+
+ /// consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer
+ /// returning the number of consumed bytes
+ macro_rules! consume_dec_digits {
+ ($bytes:ident) => {{
+ let start_len = $bytes.len();
+ while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
+ $bytes = rest;
+ }
+ start_len - $bytes.len()
+ }};
+ }
+
+ /// consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given byte buffer
+ /// returning the number of consumed bytes
+ macro_rules! consume_hex_digits {
+ ($bytes:ident) => {{
+ let start_len = $bytes.len();
+ while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
+ $bytes = rest;
+ }
+ start_len - $bytes.len()
+ }};
+ }
+
+ /// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str`
+ macro_rules! rest_to_str {
+ ($bytes:ident) => {
+ &input[input.len() - $bytes.len()..]
+ };
+ }
+
+ struct ExtractSubStr<'a>(&'a str);
+
+ impl<'a> ExtractSubStr<'a> {
+ /// given an `input` and a `start` (tail of the `input`)
+ /// creates a new [`ExtractSubStr`](`Self`)
+ fn start(input: &'a str, start: &'a [u8]) -> Self {
+ let start = input.len() - start.len();
+ Self(&input[start..])
+ }
+ /// given an `end` (tail of the initial `input`)
+ /// returns a substring of `input`
+ fn end(&self, end: &'a [u8]) -> &'a str {
+ let end = self.0.len() - end.len();
+ &self.0[..end]
+ }
+ }
+
+ let mut bytes = input.as_bytes();
+
+ let general_extract = ExtractSubStr::start(input, bytes);
+
+ let is_negative = consume!(bytes, b'-');
+
+ if consume!(bytes, b'0', b'x' | b'X') {
+ let digits_extract = ExtractSubStr::start(input, bytes);
+
+ let consumed = consume_hex_digits!(bytes);
+
+ if consume!(bytes, b'.') {
+ let consumed_after_period = consume_hex_digits!(bytes);
+
+ if consumed + consumed_after_period == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let significand = general_extract.end(bytes);
+
+ if consume!(bytes, b'p' | b'P') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
+
+ (parse_hex_float(number, kind), rest_to_str!(bytes))
+ } else {
+ (
+ parse_hex_float_missing_exponent(significand, None),
+ rest_to_str!(bytes),
+ )
+ }
+ } else {
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let significand = general_extract.end(bytes);
+ let digits = digits_extract.end(bytes);
+
+ let exp_extract = ExtractSubStr::start(input, bytes);
+
+ if consume!(bytes, b'p' | b'P') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let exponent = exp_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
+
+ (
+ parse_hex_float_missing_period(significand, exponent, kind),
+ rest_to_str!(bytes),
+ )
+ } else {
+ let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]);
+
+ (
+ parse_hex_int(is_negative, digits, kind),
+ rest_to_str!(bytes),
+ )
+ }
+ }
+ } else {
+ let is_first_zero = bytes.first() == Some(&b'0');
+
+ let consumed = consume_dec_digits!(bytes);
+
+ if consume!(bytes, b'.') {
+ let consumed_after_period = consume_dec_digits!(bytes);
+
+ if consumed + consumed_after_period == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ if consume!(bytes, b'e' | b'E') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
+
+ (parse_dec_float(number, kind), rest_to_str!(bytes))
+ } else {
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ if consume!(bytes, b'e' | b'E') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
+
+ (parse_dec_float(number, kind), rest_to_str!(bytes))
+ } else {
+ // make sure the multi-digit numbers don't start with zero
+ if consumed > 1 && is_first_zero {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let digits_with_sign = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [
+ b'i' => Kind::Int(IntKind::I32),
+ b'u' => Kind::Int(IntKind::U32),
+ b'f' => Kind::Float(FloatKind::F32),
+ b'h' => Kind::Float(FloatKind::F16)
+ ]);
+
+ (
+ parse_dec(is_negative, digits_with_sign, kind),
+ rest_to_str!(bytes),
+ )
+ }
+ }
+ }
+}
+
+fn parse_hex_float_missing_exponent(
+ // format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
+ significand: &str,
+ kind: Option<FloatKind>,
+) -> Result<Number, NumberError> {
+ let hexf_input = format!("{}{}", significand, "p0");
+ parse_hex_float(&hexf_input, kind)
+}
+
+fn parse_hex_float_missing_period(
+ // format: -?0[xX] [0-9a-fA-F]+
+ significand: &str,
+ // format: [pP][+-]?[0-9]+
+ exponent: &str,
+ kind: Option<FloatKind>,
+) -> Result<Number, NumberError> {
+ let hexf_input = format!("{significand}.{exponent}");
+ parse_hex_float(&hexf_input, kind)
+}
+
+fn parse_hex_int(
+ is_negative: bool,
+ // format: [0-9a-fA-F]+
+ digits: &str,
+ kind: Option<IntKind>,
+) -> Result<Number, NumberError> {
+ let digits_with_sign = if is_negative {
+ Cow::Owned(format!("-{digits}"))
+ } else {
+ Cow::Borrowed(digits)
+ };
+ parse_int(&digits_with_sign, kind, 16, is_negative)
+}
+
+fn parse_dec(
+ is_negative: bool,
+ // format: -? ( [0-9] | [1-9][0-9]+ )
+ digits_with_sign: &str,
+ kind: Option<Kind>,
+) -> Result<Number, NumberError> {
+ match kind {
+ None => parse_int(digits_with_sign, None, 10, is_negative),
+ Some(Kind::Int(kind)) => parse_int(digits_with_sign, Some(kind), 10, is_negative),
+ Some(Kind::Float(kind)) => parse_dec_float(digits_with_sign, Some(kind)),
+ }
+}
+
+// Float parsing notes
+
+// The following chapters of IEEE 754-2019 are relevant:
+//
+// 7.4 Overflow (largest finite number is exceeded by what would have been
+// the rounded floating-point result were the exponent range unbounded)
+//
+// 7.5 Underflow (tiny non-zero result is detected;
+// for decimal formats tininess is detected before rounding when a non-zero result
+// computed as though both the exponent range and the precision were unbounded
+// would lie strictly between 2^−126)
+//
+// 7.6 Inexact (rounded result differs from what would have been computed
+// were both exponent range and precision unbounded)
+
+// The WGSL spec requires us to error:
+// on overflow for decimal floating point literals
+// on overflow and inexact for hexadecimal floating point literals
+// (underflow is not mentioned)
+
+// hexf_parse errors on overflow, underflow, inexact
+// rust std lib float from str handles overflow, underflow, inexact transparently (rounds and will not error)
+
+// Therefore we only check for overflow manually for decimal floating point literals
+
+// input format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
+fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
+ match kind {
+ None => match hexf_parse::parse_hexf64(input, false) {
+ Ok(num) => Ok(Number::AbstractFloat(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
+ Ok(num) => Ok(Number::F32(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
+ }
+}
+
+// input format: -? ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
+// | -? [0-9]+ [eE][+-]?[0-9]+
+fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
+ match kind {
+ None => {
+ let num = input.parse::<f64>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::AbstractFloat(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F32) => {
+ let num = input.parse::<f32>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::F32(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
+ }
+}
+
+fn parse_int(
+ input: &str,
+ kind: Option<IntKind>,
+ radix: u32,
+ is_negative: bool,
+) -> Result<Number, NumberError> {
+ fn map_err(e: core::num::ParseIntError) -> NumberError {
+ match *e.kind() {
+ core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
+ NumberError::NotRepresentable
+ }
+ _ => unreachable!(),
+ }
+ }
+ match kind {
+ None => match i64::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::AbstractInt(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ Some(IntKind::I32) => match i32::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::I32(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ Some(IntKind::U32) if is_negative => Err(NumberError::NotRepresentable),
+ Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::U32(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/tests.rs b/third_party/rust/naga/src/front/wgsl/tests.rs
new file mode 100644
index 0000000000..31b5890707
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/tests.rs
@@ -0,0 +1,483 @@
+use super::parse_str;
+
+#[test]
+fn parse_comment() {
+ parse_str(
+ "//
+ ////
+ ///////////////////////////////////////////////////////// asda
+ //////////////////// dad ////////// /
+ /////////////////////////////////////////////////////////////////////////////////////////////////////
+ //
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_types() {
+ parse_str("const a : i32 = 2;").unwrap();
+ assert!(parse_str("const a : x32 = 2;").is_err());
+ parse_str("var t: texture_2d<f32>;").unwrap();
+ parse_str("var t: texture_cube_array<i32>;").unwrap();
+ parse_str("var t: texture_multisampled_2d<u32>;").unwrap();
+ parse_str("var t: texture_storage_1d<rgba8uint,write>;").unwrap();
+ parse_str("var t: texture_storage_3d<r32float,read>;").unwrap();
+}
+
+#[test]
+fn parse_type_inference() {
+ parse_str(
+ "
+ fn foo() {
+ let a = 2u;
+ let b: u32 = a;
+ var x = 3.;
+ var y = vec2<f32>(1, 2);
+ }",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn foo() { let c : i32 = 2.0; }",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_type_cast() {
+ parse_str(
+ "
+ const a : i32 = 2;
+ fn main() {
+ var x: f32 = f32(a);
+ x = f32(i32(a + 1) / 2);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(1.0, 2.0);
+ let y: vec2<u32> = vec2<u32>(x);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0.0);
+ }
+ ",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0);
+ }
+ ",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_struct() {
+ parse_str(
+ "
+ struct Foo { x: i32 }
+ struct Bar {
+ @size(16) x: vec2<i32>,
+ @align(16) y: f32,
+ @size(32) @align(128) z: vec3<f32>,
+ };
+ struct Empty {}
+ var<storage,read_write> s: Foo;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_standard_fun() {
+ parse_str(
+ "
+ fn main() {
+ var x: i32 = min(max(1, 2), 3);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_statement() {
+ parse_str(
+ "
+ fn main() {
+ ;
+ {}
+ {;}
+ }
+ ",
+ )
+ .unwrap();
+
+ parse_str(
+ "
+ fn foo() {}
+ fn bar() { foo(); }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_if() {
+ parse_str(
+ "
+ fn main() {
+ if true {
+ discard;
+ } else {}
+ if 0 != 1 {}
+ if false {
+ return;
+ } else if true {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_if() {
+ parse_str(
+ "
+ fn main() {
+ if (true) {
+ discard;
+ } else {}
+ if (0 != 1) {}
+ if (false) {
+ return;
+ } else if (true) {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_loop() {
+ parse_str(
+ "
+ fn main() {
+ var i: i32 = 0;
+ loop {
+ if i == 1 { break; }
+ continuing { i = 1; }
+ }
+ loop {
+ if i == 0 { continue; }
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var found: bool = false;
+ var i: i32 = 0;
+ while !found {
+ if i == 10 {
+ found = true;
+ }
+
+ i = i + 1;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ while true {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var a: i32 = 0;
+ for(var i: i32 = 0; i < 4; i = i + 1) {
+ a = a + 2;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ for(;;) {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: { pos = 1.0; }
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_optional_colon_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1 { pos = 0.0; }
+ case 2 { pos = 1.0; }
+ default { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_default_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: {}
+ case default, 3: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch pos > 1.0 {
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_load() {
+ parse_str(
+ "
+ var t: texture_3d<u32>;
+ fn foo() {
+ let r: vec4<u32> = textureLoad(t, vec3<u32>(0.0, 1.0, 2.0), 1);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<i32>;
+ fn foo() {
+ let r: vec4<i32> = textureLoad(t, vec2<i32>(10, 20), 2, 3);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_storage_1d_array<r32float,read>;
+ fn foo() {
+ let r: vec4<f32> = textureLoad(t, 10, 2);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_store() {
+ parse_str(
+ "
+ var t: texture_storage_2d<rgba8unorm,write>;
+ fn foo() {
+ textureStore(t, vec2<i32>(10, 20), vec4<f32>(0.0, 1.0, 2.0, 3.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_query() {
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<f32>;
+ fn foo() {
+ var dim: vec2<u32> = textureDimensions(t);
+ dim = textureDimensions(t, 0);
+ let layers: u32 = textureNumLayers(t);
+ let samples: u32 = textureNumSamples(t);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_postfix() {
+ parse_str(
+ "fn foo() {
+ let x: f32 = vec4<f32>(1.0, 2.0, 3.0, 4.0).xyz.rgbr.aaaa.wz.g;
+ let y: f32 = fract(vec2<f32>(0.5, x)).x;
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_expressions() {
+ parse_str("fn foo() {
+ let x: f32 = select(0.0, 1.0, true);
+ let y: vec2<f32> = select(vec2<f32>(1.0, 1.0), vec2<f32>(x, x), vec2<bool>(x < 0.5, x > 0.5));
+ let z: bool = !(0.0 == 1.0);
+ }").unwrap();
+}
+
+#[test]
+fn parse_pointers() {
+ parse_str(
+ "fn foo() {
+ var x: f32 = 1.0;
+ let px = &x;
+ let py = frexp(0.5, px);
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_struct_instantiation() {
+ parse_str(
+ "
+ struct Foo {
+ a: f32,
+ b: vec3<f32>,
+ }
+
+ @fragment
+ fn fs_main() {
+ var foo: Foo = Foo(0.0, vec3<f32>(0.0, 1.0, 42.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_array_length() {
+ parse_str(
+ "
+ struct Foo {
+ data: array<u32>
+ } // this is used as both input and output for convenience
+
+ @group(0) @binding(0)
+ var<storage> foo: Foo;
+
+ @group(0) @binding(1)
+ var<storage> bar: array<u32>;
+
+ fn baz() {
+ var x: u32 = arrayLength(foo.data);
+ var y: u32 = arrayLength(bar);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_storage_buffers() {
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read_write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_alias() {
+ parse_str(
+ "
+ alias Vec4 = vec4<f32>;
+ ",
+ )
+ .unwrap();
+}