From 26a029d407be480d791972afb5975cf62c9360a6 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 19 Apr 2024 02:47:55 +0200 Subject: Adding upstream version 124.0.1. Signed-off-by: Daniel Baumann --- .../rust/naga/src/front/wgsl/lower/construction.rs | 616 +++++ .../rust/naga/src/front/wgsl/lower/conversion.rs | 503 ++++ third_party/rust/naga/src/front/wgsl/lower/mod.rs | 2760 ++++++++++++++++++++ 3 files changed, 3879 insertions(+) create mode 100644 third_party/rust/naga/src/front/wgsl/lower/construction.rs create mode 100644 third_party/rust/naga/src/front/wgsl/lower/conversion.rs create mode 100644 third_party/rust/naga/src/front/wgsl/lower/mod.rs (limited to 'third_party/rust/naga/src/front/wgsl/lower') diff --git a/third_party/rust/naga/src/front/wgsl/lower/construction.rs b/third_party/rust/naga/src/front/wgsl/lower/construction.rs new file mode 100644 index 0000000000..de0d11d227 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lower/construction.rs @@ -0,0 +1,616 @@ +use std::num::NonZeroU32; + +use crate::front::wgsl::parse::ast; +use crate::{Handle, Span}; + +use crate::front::wgsl::error::Error; +use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; + +/// A cooked form of `ast::ConstructorType` that uses Naga types whenever +/// possible. +enum Constructor { + /// A vector construction whose component type is inferred from the + /// argument: `vec3(1.0)`. + PartialVector { size: crate::VectorSize }, + + /// A matrix construction whose component type is inferred from the + /// argument: `mat2x2(1,2,3,4)`. + PartialMatrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + }, + + /// An array whose component type and size are inferred from the arguments: + /// `array(3,4,5)`. + PartialArray, + + /// A known Naga type. + /// + /// When we match on this type, we need to see the `TypeInner` here, but at + /// the point that we build this value we'll still need mutable access to + /// the module later. To avoid borrowing from the module, the type parameter + /// `T` is `Handle` initially. Then we use `borrow_inner` to produce a + /// version holding a tuple `(Handle, &TypeInner)`. + Type(T), +} + +impl Constructor> { + /// Return an equivalent `Constructor` value that includes borrowed + /// `TypeInner` values alongside any type handles. + /// + /// The returned form is more convenient to match on, since the patterns + /// can actually see what the handle refers to. + fn borrow_inner( + self, + module: &crate::Module, + ) -> Constructor<(Handle, &crate::TypeInner)> { + match self { + Constructor::PartialVector { size } => Constructor::PartialVector { size }, + Constructor::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } + } + Constructor::PartialArray => Constructor::PartialArray, + Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)), + } + } +} + +impl Constructor<(Handle, &crate::TypeInner)> { + fn to_error_string(&self, ctx: &ExpressionContext) -> String { + match *self { + Self::PartialVector { size } => { + format!("vec{}", size as u32,) + } + Self::PartialMatrix { columns, rows } => { + format!("mat{}x{}", columns as u32, rows as u32,) + } + Self::PartialArray => "array".to_string(), + Self::Type((handle, _inner)) => handle.to_wgsl(&ctx.module.to_ctx()), + } + } +} + +enum Components<'a> { + None, + One { + component: Handle, + span: Span, + ty_inner: &'a crate::TypeInner, + }, + Many { + components: Vec>, + spans: Vec, + }, +} + +impl Components<'_> { + fn into_components_vec(self) -> Vec> { + match self { + Self::None => vec![], + Self::One { component, .. } => vec![component], + Self::Many { components, .. } => components, + } + } +} + +impl<'source, 'temp> Lowerer<'source, 'temp> { + /// Generate Naga IR for a type constructor expression. + /// + /// The `constructor` value represents the head of the constructor + /// expression, which is at least a hint of which type is being built; if + /// it's one of the `Partial` variants, we need to consider the argument + /// types as well. + /// + /// This is used for [`Construct`] expressions, but also for [`Call`] + /// expressions, once we've determined that the "callable" (in WGSL spec + /// terms) is actually a type. + /// + /// [`Construct`]: ast::Expression::Construct + /// [`Call`]: ast::Expression::Call + pub fn construct( + &mut self, + span: Span, + constructor: &ast::ConstructorType<'source>, + ty_span: Span, + components: &[Handle>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + use crate::proc::TypeResolution as Tr; + + let constructor_h = self.constructor(constructor, ctx)?; + + let components = match *components { + [] => Components::None, + [component] => { + let span = ctx.ast_expressions.get_span(component); + let component = self.expression_for_abstract(component, ctx)?; + let ty_inner = super::resolve_inner!(ctx, component); + + Components::One { + component, + span, + ty_inner, + } + } + ref ast_components @ [_, _, ..] => { + let components = ast_components + .iter() + .map(|&expr| self.expression_for_abstract(expr, ctx)) + .collect::>()?; + let spans = ast_components + .iter() + .map(|&expr| ctx.ast_expressions.get_span(expr)) + .collect(); + + for &component in &components { + ctx.grow_types(component)?; + } + + Components::Many { components, spans } + } + }; + + // Even though we computed `constructor` above, wait until now to borrow + // a reference to the `TypeInner`, so that the component-handling code + // above can have mutable access to the type arena. + let constructor = constructor_h.borrow_inner(ctx.module); + + let expr; + match (components, constructor) { + // Empty constructor + (Components::None, dst_ty) => match dst_ty { + Constructor::Type((result_ty, _)) => { + return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span) + } + Constructor::PartialVector { .. } + | Constructor::PartialMatrix { .. } + | Constructor::PartialArray => { + // We have no arguments from which to infer the result type, so + // partial constructors aren't acceptable here. + return Err(Error::TypeNotInferable(ty_span)); + } + }, + + // Scalar constructor & conversion (scalar -> scalar) + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Scalar { .. }, + .. + }, + Constructor::Type((_, &crate::TypeInner::Scalar(scalar))), + ) => { + expr = crate::Expression::As { + expr: component, + kind: scalar.kind, + convert: Some(scalar.width), + }; + } + + // Vector conversion (vector -> vector) + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, + .. + }, + Constructor::Type(( + _, + &crate::TypeInner::Vector { + size: dst_size, + scalar: dst_scalar, + }, + )), + ) if dst_size == src_size => { + expr = crate::Expression::As { + expr: component, + kind: dst_scalar.kind, + convert: Some(dst_scalar.width), + }; + } + + // Vector conversion (vector -> vector) - partial + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, + .. + }, + Constructor::PartialVector { size: dst_size }, + ) if dst_size == src_size => { + // This is a trivial conversion: the sizes match, and a Partial + // constructor doesn't specify a scalar type, so nothing can + // possibly happen. + return Ok(component); + } + + // Matrix conversion (matrix -> matrix) + ( + Components::One { + component, + ty_inner: + &crate::TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + Constructor::Type(( + _, + &crate::TypeInner::Matrix { + columns: dst_columns, + rows: dst_rows, + scalar: dst_scalar, + }, + )), + ) if dst_columns == src_columns && dst_rows == src_rows => { + expr = crate::Expression::As { + expr: component, + kind: dst_scalar.kind, + convert: Some(dst_scalar.width), + }; + } + + // Matrix conversion (matrix -> matrix) - partial + ( + Components::One { + component, + ty_inner: + &crate::TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + Constructor::PartialMatrix { + columns: dst_columns, + rows: dst_rows, + }, + ) if dst_columns == src_columns && dst_rows == src_rows => { + // This is a trivial conversion: the sizes match, and a Partial + // constructor doesn't specify a scalar type, so nothing can + // possibly happen. + return Ok(component); + } + + // Vector constructor (splat) - infer type + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Scalar { .. }, + .. + }, + Constructor::PartialVector { size }, + ) => { + expr = crate::Expression::Splat { + size, + value: component, + }; + } + + // Vector constructor (splat) + ( + Components::One { + mut component, + ty_inner: &crate::TypeInner::Scalar(_), + .. + }, + Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })), + ) => { + ctx.convert_slice_to_common_leaf_scalar( + std::slice::from_mut(&mut component), + scalar, + )?; + expr = crate::Expression::Splat { + size, + value: component, + }; + } + + // Vector constructor (by elements), partial + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialVector { size }, + ) => { + let consensus_scalar = + ctx.automatic_conversion_consensus(&components) + .map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + let inner = consensus_scalar.to_inner_vector(size); + let ty = ctx.ensure_type_exists(inner); + expr = crate::Expression::Compose { ty, components }; + } + + // Vector constructor (by elements), full type given + ( + Components::Many { mut components, .. }, + Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })), + ) => { + ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?; + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by elements), partial + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialMatrix { columns, rows }, + ) if components.len() == columns as usize * rows as usize => { + let consensus_scalar = + ctx.automatic_conversion_consensus(&components) + .map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + // We actually only accept floating-point elements. + let consensus_scalar = consensus_scalar + .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT) + .unwrap_or(consensus_scalar); + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows)); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.append_expression( + crate::Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect::, _>>()?; + + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by elements), type given + ( + Components::Many { mut components, .. }, + Constructor::Type(( + _, + &crate::TypeInner::Matrix { + columns, + rows, + scalar, + }, + )), + ) if components.len() == columns as usize * rows as usize => { + let element = Tr::Value(crate::TypeInner::Scalar(scalar)); + ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?; + let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows)); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.append_expression( + crate::Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect::, _>>()?; + + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by columns), partial + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialMatrix { columns, rows }, + ) => { + let consensus_scalar = + ctx.automatic_conversion_consensus(&components) + .map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by columns), type given + ( + Components::Many { mut components, .. }, + Constructor::Type(( + ty, + &crate::TypeInner::Matrix { + columns: _, + rows, + scalar, + }, + )), + ) => { + let component_ty = crate::TypeInner::Vector { size: rows, scalar }; + ctx.try_automatic_conversions_slice( + &mut components, + &Tr::Value(component_ty), + ty_span, + )?; + expr = crate::Expression::Compose { ty, components }; + } + + // Array constructor - infer type + (components, Constructor::PartialArray) => { + let mut components = components.into_components_vec(); + if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(&components) { + // Note that this will *not* necessarily convert all the + // components to the same type! The `automatic_conversion_consensus` + // method only considers the parameters' leaf scalar + // types; the parameters themselves could be any mix of + // vectors, matrices, and scalars. + // + // But *if* it is possible for this array construction + // expression to be well-typed at all, then all the + // parameters must have the same type constructors (vec, + // matrix, scalar) applied to their leaf scalars, so + // reconciling their scalars is always the right thing to + // do. And if this array construction is not well-typed, + // these conversions will not make it so, and we can let + // validation catch the error. + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + } else { + // There's no consensus scalar. Emit the `Compose` + // expression anyway, and let validation catch the problem. + } + + let base = ctx.register_type(components[0])?; + + let inner = crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant( + NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(), + ), + stride: { + self.layouter.update(ctx.module.to_ctx()).unwrap(); + self.layouter[base].to_stride() + }, + }; + let ty = ctx.ensure_type_exists(inner); + + expr = crate::Expression::Compose { ty, components }; + } + + // Array constructor, explicit type + (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => { + let mut components = components.into_components_vec(); + ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?; + expr = crate::Expression::Compose { ty, components }; + } + + // Struct constructor + ( + components, + Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })), + ) => { + let mut components = components.into_components_vec(); + let struct_ty_span = ctx.module.types.get_span(ty); + + // Make a vector of the members' type handles in advance, to + // avoid borrowing `members` from `ctx` while we generate + // new code. + let members: Vec> = members.iter().map(|m| m.ty).collect(); + + for (component, &ty) in components.iter_mut().zip(&members) { + *component = + ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?; + } + expr = crate::Expression::Compose { ty, components }; + } + + // ERRORS + + // Bad conversion (type cast) + (Components::One { span, ty_inner, .. }, constructor) => { + let from_type = ty_inner.to_wgsl(&ctx.module.to_ctx()); + return Err(Error::BadTypeCast { + span, + from_type, + to_type: constructor.to_error_string(ctx), + }); + } + + // Too many parameters for scalar constructor + ( + Components::Many { spans, .. }, + Constructor::Type((_, &crate::TypeInner::Scalar { .. })), + ) => { + let span = spans[1].until(spans.last().unwrap()); + return Err(Error::UnexpectedComponents(span)); + } + + // Other types can't be constructed + _ => return Err(Error::TypeNotConstructible(ty_span)), + } + + let expr = ctx.append_expression(expr, span)?; + Ok(expr) + } + + /// Build a [`Constructor`] for a WGSL construction expression. + /// + /// If `constructor` conveys enough information to determine which Naga [`Type`] + /// we're actually building (i.e., it's not a partial constructor), then + /// ensure the `Type` exists in [`ctx.module`], and return + /// [`Constructor::Type`]. + /// + /// Otherwise, return the [`Constructor`] partial variant corresponding to + /// `constructor`. + /// + /// [`Type`]: crate::Type + /// [`ctx.module`]: ExpressionContext::module + fn constructor<'out>( + &mut self, + constructor: &ast::ConstructorType<'source>, + ctx: &mut ExpressionContext<'source, '_, 'out>, + ) -> Result>, Error<'source>> { + let handle = match *constructor { + ast::ConstructorType::Scalar(scalar) => { + let ty = ctx.ensure_type_exists(scalar.to_inner_scalar()); + Constructor::Type(ty) + } + ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size }, + ast::ConstructorType::Vector { size, scalar } => { + let ty = ctx.ensure_type_exists(scalar.to_inner_vector(size)); + Constructor::Type(ty) + } + ast::ConstructorType::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } + } + ast::ConstructorType::Matrix { + rows, + columns, + width, + } => { + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: crate::Scalar::float(width), + }); + Constructor::Type(ty) + } + ast::ConstructorType::PartialArray => Constructor::PartialArray, + ast::ConstructorType::Array { base, size } => { + let base = self.resolve_ast_type(base, &mut ctx.as_global())?; + let size = self.array_size(size, &mut ctx.as_global())?; + + self.layouter.update(ctx.module.to_ctx()).unwrap(); + let stride = self.layouter[base].to_stride(); + + let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride }); + Constructor::Type(ty) + } + ast::ConstructorType::Type(ty) => Constructor::Type(ty), + }; + + Ok(handle) + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lower/conversion.rs b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs new file mode 100644 index 0000000000..2a2690f096 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs @@ -0,0 +1,503 @@ +//! WGSL's automatic conversions for abstract types. + +use crate::{Handle, Span}; + +impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> { + /// Try to use WGSL's automatic conversions to convert `expr` to `goal_ty`. + /// + /// If no conversions are necessary, return `expr` unchanged. + /// + /// If automatic conversions cannot convert `expr` to `goal_ty`, return an + /// [`AutoConversion`] error. + /// + /// Although the Load Rule is one of the automatic conversions, this + /// function assumes it has already been applied if appropriate, as + /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. + /// + /// [`AutoConversion`]: super::Error::AutoConversion + pub fn try_automatic_conversions( + &mut self, + expr: Handle, + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result, super::Error<'source>> { + let expr_span = self.get_expression_span(expr); + // Keep the TypeResolution so we can get type names for + // structs in error messages. + let expr_resolution = super::resolve!(self, expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + let goal_inner = goal_ty.inner_with(types); + + // If `expr` already has the requested type, we're done. + if expr_inner.equivalent(goal_inner, types) { + return Ok(expr); + } + + let (_expr_scalar, goal_scalar) = + match expr_inner.automatically_converts_to(goal_inner, types) { + Some(scalars) => scalars, + None => { + let gctx = &self.module.to_ctx(); + let source_type = expr_resolution.to_wgsl(gctx); + let dest_type = goal_ty.to_wgsl(gctx); + + return Err(super::Error::AutoConversion { + dest_span: goal_span, + dest_type, + source_span: expr_span, + source_type, + }); + } + }; + + self.convert_leaf_scalar(expr, expr_span, goal_scalar) + } + + /// Try to convert `expr`'s leaf scalar to `goal` using automatic conversions. + /// + /// If no conversions are necessary, return `expr` unchanged. + /// + /// If automatic conversions cannot convert `expr` to `goal_scalar`, return + /// an [`AutoConversionLeafScalar`] error. + /// + /// Although the Load Rule is one of the automatic conversions, this + /// function assumes it has already been applied if appropriate, as + /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. + /// + /// [`AutoConversionLeafScalar`]: super::Error::AutoConversionLeafScalar + pub fn try_automatic_conversion_for_leaf_scalar( + &mut self, + expr: Handle, + goal_scalar: crate::Scalar, + goal_span: Span, + ) -> Result, super::Error<'source>> { + let expr_span = self.get_expression_span(expr); + let expr_resolution = super::resolve!(self, expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + + let make_error = || { + let gctx = &self.module.to_ctx(); + let source_type = expr_resolution.to_wgsl(gctx); + super::Error::AutoConversionLeafScalar { + dest_span: goal_span, + dest_scalar: goal_scalar.to_wgsl(), + source_span: expr_span, + source_type, + } + }; + + let expr_scalar = match expr_inner.scalar() { + Some(scalar) => scalar, + None => return Err(make_error()), + }; + + if expr_scalar == goal_scalar { + return Ok(expr); + } + + if !expr_scalar.automatically_converts_to(goal_scalar) { + return Err(make_error()); + } + + assert!(expr_scalar.is_abstract()); + + self.convert_leaf_scalar(expr, expr_span, goal_scalar) + } + + fn convert_leaf_scalar( + &mut self, + expr: Handle, + expr_span: Span, + goal_scalar: crate::Scalar, + ) -> Result, super::Error<'source>> { + let expr_inner = super::resolve_inner!(self, expr); + if let crate::TypeInner::Array { .. } = *expr_inner { + self.as_const_evaluator() + .cast_array(expr, goal_scalar, expr_span) + .map_err(|err| super::Error::ConstantEvaluatorError(err, expr_span)) + } else { + let cast = crate::Expression::As { + expr, + kind: goal_scalar.kind, + convert: Some(goal_scalar.width), + }; + self.append_expression(cast, expr_span) + } + } + + /// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions. + pub fn try_automatic_conversions_slice( + &mut self, + exprs: &mut [Handle], + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + *expr = self.try_automatic_conversions(*expr, goal_ty, goal_span)?; + } + + Ok(()) + } + + /// Apply WGSL's automatic conversions to a vector constructor's arguments. + /// + /// When calling a vector constructor like `vec3(...)`, the parameters + /// can be a mix of scalars and vectors, with the latter being spread out to + /// contribute each of their components as a component of the new value. + /// When the element type is explicit, as with `` in the example above, + /// WGSL's automatic conversions should convert abstract scalar and vector + /// parameters to the constructor's required scalar type. + pub fn try_automatic_conversions_for_vector( + &mut self, + exprs: &mut [Handle], + goal_scalar: crate::Scalar, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + use crate::proc::TypeResolution as Tr; + use crate::TypeInner as Ti; + let goal_scalar_res = Tr::Value(Ti::Scalar(goal_scalar)); + + for (i, expr) in exprs.iter_mut().enumerate() { + // Keep the TypeResolution so we can get full type names + // in error messages. + let expr_resolution = super::resolve!(self, *expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + + match *expr_inner { + Ti::Scalar(_) => { + *expr = self.try_automatic_conversions(*expr, &goal_scalar_res, goal_span)?; + } + Ti::Vector { size, scalar: _ } => { + let goal_vector_res = Tr::Value(Ti::Vector { + size, + scalar: goal_scalar, + }); + *expr = self.try_automatic_conversions(*expr, &goal_vector_res, goal_span)?; + } + _ => { + let span = self.get_expression_span(*expr); + return Err(super::Error::InvalidConstructorComponentType( + span, i as i32, + )); + } + } + } + + Ok(()) + } + + /// Convert `expr` to the leaf scalar type `scalar`. + pub fn convert_to_leaf_scalar( + &mut self, + expr: &mut Handle, + goal: crate::Scalar, + ) -> Result<(), super::Error<'source>> { + let inner = super::resolve_inner!(self, *expr); + // Do nothing if `inner` doesn't even have leaf scalars; + // it's a type error that validation will catch. + if inner.scalar() != Some(goal) { + let cast = crate::Expression::As { + expr: *expr, + kind: goal.kind, + convert: Some(goal.width), + }; + let expr_span = self.get_expression_span(*expr); + *expr = self.append_expression(cast, expr_span)?; + } + + Ok(()) + } + + /// Convert all expressions in `exprs` to a common scalar type. + /// + /// Note that the caller is responsible for making sure these + /// conversions are actually justified. This function simply + /// generates `As` expressions, regardless of whether they are + /// permitted WGSL automatic conversions. Callers intending to + /// implement automatic conversions need to determine for + /// themselves whether the casts we we generate are justified, + /// perhaps by calling `TypeInner::automatically_converts_to` or + /// `Scalar::automatic_conversion_combine`. + pub fn convert_slice_to_common_leaf_scalar( + &mut self, + exprs: &mut [Handle], + goal: crate::Scalar, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + self.convert_to_leaf_scalar(expr, goal)?; + } + + Ok(()) + } + + /// Return an expression for the concretized value of `expr`. + /// + /// If `expr` is already concrete, return it unchanged. + pub fn concretize( + &mut self, + mut expr: Handle, + ) -> Result, super::Error<'source>> { + let inner = super::resolve_inner!(self, expr); + if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) { + let concretized = scalar.concretize(); + if concretized != scalar { + assert!(scalar.is_abstract()); + let expr_span = self.get_expression_span(expr); + expr = self + .as_const_evaluator() + .cast_array(expr, concretized, expr_span) + .map_err(|err| { + // A `TypeResolution` includes the type's full name, if + // it has one. Also, avoid holding the borrow of `inner` + // across the call to `cast_array`. + let expr_type = &self.typifier()[expr]; + super::Error::ConcretizationFailed { + expr_span, + expr_type: expr_type.to_wgsl(&self.module.to_ctx()), + scalar: concretized.to_wgsl(), + inner: err, + } + })?; + } + } + + Ok(expr) + } + + /// Find the consensus scalar of `components` under WGSL's automatic + /// conversions. + /// + /// If `components` can all be converted to any common scalar via + /// WGSL's automatic conversions, return the best such scalar. + /// + /// The `components` slice must not be empty. All elements' types must + /// have been resolved. + /// + /// If `components` are definitely not acceptable as arguments to such + /// constructors, return `Err(i)`, where `i` is the index in + /// `components` of some problematic argument. + /// + /// This function doesn't fully type-check the arguments - it only + /// considers their leaf scalar types. This means it may return `Ok` + /// even when the Naga validator will reject the resulting + /// construction expression later. + pub fn automatic_conversion_consensus<'handle, I>( + &self, + components: I, + ) -> Result + where + I: IntoIterator>, + I::IntoIter: Clone, // for debugging + { + let types = &self.module.types; + let mut inners = components + .into_iter() + .map(|&c| self.typifier()[c].inner_with(types)); + log::debug!( + "wgsl automatic_conversion_consensus: {:?}", + inners + .clone() + .map(|inner| inner.to_wgsl(&self.module.to_ctx())) + .collect::>() + ); + let mut best = inners.next().unwrap().scalar().ok_or(0_usize)?; + for (inner, i) in inners.zip(1..) { + let scalar = inner.scalar().ok_or(i)?; + match best.automatic_conversion_combine(scalar) { + Some(new_best) => { + best = new_best; + } + None => return Err(i), + } + } + + log::debug!(" consensus: {:?}", best.to_wgsl()); + Ok(best) + } +} + +impl crate::TypeInner { + /// Determine whether `self` automatically converts to `goal`. + /// + /// If WGSL's automatic conversions (excluding the Load Rule) will + /// convert `self` to `goal`, then return a pair `(from, to)`, + /// where `from` and `to` are the scalar types of the leaf values + /// of `self` and `goal`. + /// + /// This function assumes that `self` and `goal` are different + /// types. Callers should first check whether any conversion is + /// needed at all. + /// + /// If the automatic conversions cannot convert `self` to `goal`, + /// return `None`. + fn automatically_converts_to( + &self, + goal: &Self, + types: &crate::UniqueArena, + ) -> Option<(crate::Scalar, crate::Scalar)> { + use crate::ScalarKind as Sk; + use crate::TypeInner as Ti; + + // Automatic conversions only change the scalar type of a value's leaves + // (e.g., `vec4` to `vec4`), never the type + // constructors applied to those scalar types (e.g., never scalar to + // `vec4`, or `vec2` to `vec3`). So first we check that the type + // constructors match, extracting the leaf scalar types in the process. + let expr_scalar; + let goal_scalar; + match (self, goal) { + (&Ti::Scalar(expr), &Ti::Scalar(goal)) => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Vector { + size: expr_size, + scalar: expr, + }, + &Ti::Vector { + size: goal_size, + scalar: goal, + }, + ) if expr_size == goal_size => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Matrix { + rows: expr_rows, + columns: expr_columns, + scalar: expr, + }, + &Ti::Matrix { + rows: goal_rows, + columns: goal_columns, + scalar: goal, + }, + ) if expr_rows == goal_rows && expr_columns == goal_columns => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Array { + base: expr_base, + size: expr_size, + stride: _, + }, + &Ti::Array { + base: goal_base, + size: goal_size, + stride: _, + }, + ) if expr_size == goal_size => { + return types[expr_base] + .inner + .automatically_converts_to(&types[goal_base].inner, types); + } + _ => return None, + } + + match (expr_scalar.kind, goal_scalar.kind) { + (Sk::AbstractFloat, Sk::Float) => {} + (Sk::AbstractInt, Sk::Sint | Sk::Uint | Sk::AbstractFloat | Sk::Float) => {} + _ => return None, + } + + log::trace!(" okay: expr {expr_scalar:?}, goal {goal_scalar:?}"); + Some((expr_scalar, goal_scalar)) + } + + fn automatically_convertible_scalar( + &self, + types: &crate::UniqueArena, + ) -> Option { + use crate::TypeInner as Ti; + match *self { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { + Some(scalar) + } + Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), + Ti::Atomic(_) + | Ti::Pointer { .. } + | Ti::ValuePointer { .. } + | Ti::Struct { .. } + | Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => None, + } + } +} + +impl crate::Scalar { + /// Find the common type of `self` and `other` under WGSL's + /// automatic conversions. + /// + /// If there are any scalars to which WGSL's automatic conversions + /// will convert both `self` and `other`, return the best such + /// scalar. Otherwise, return `None`. + pub const fn automatic_conversion_combine(self, other: Self) -> Option { + use crate::ScalarKind as Sk; + + match (self.kind, other.kind) { + // When the kinds match... + (Sk::AbstractFloat, Sk::AbstractFloat) + | (Sk::AbstractInt, Sk::AbstractInt) + | (Sk::Sint, Sk::Sint) + | (Sk::Uint, Sk::Uint) + | (Sk::Float, Sk::Float) + | (Sk::Bool, Sk::Bool) => { + if self.width == other.width { + // ... either no conversion is necessary ... + Some(self) + } else { + // ... or no conversion is possible. + // We never convert concrete to concrete, and + // abstract types should have only one size. + None + } + } + + // AbstractInt converts to AbstractFloat. + (Sk::AbstractFloat, Sk::AbstractInt) => Some(self), + (Sk::AbstractInt, Sk::AbstractFloat) => Some(other), + + // AbstractFloat converts to Float. + (Sk::AbstractFloat, Sk::Float) => Some(other), + (Sk::Float, Sk::AbstractFloat) => Some(self), + + // AbstractInt converts to concrete integer or float. + (Sk::AbstractInt, Sk::Uint | Sk::Sint | Sk::Float) => Some(other), + (Sk::Uint | Sk::Sint | Sk::Float, Sk::AbstractInt) => Some(self), + + // AbstractFloat can't be reconciled with concrete integer types. + (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => { + None + } + + // Nothing can be reconciled with `bool`. + (Sk::Bool, _) | (_, Sk::Bool) => None, + + // Different concrete types cannot be reconciled. + (Sk::Sint | Sk::Uint | Sk::Float, Sk::Sint | Sk::Uint | Sk::Float) => None, + } + } + + /// Return `true` if automatic conversions will covert `self` to `goal`. + pub fn automatically_converts_to(self, goal: Self) -> bool { + self.automatic_conversion_combine(goal) == Some(goal) + } + + const fn concretize(self) -> Self { + use crate::ScalarKind as Sk; + match self.kind { + Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self, + Sk::AbstractInt => Self::I32, + Sk::AbstractFloat => Self::F32, + } + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs new file mode 100644 index 0000000000..ba9b49e135 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs @@ -0,0 +1,2760 @@ +use std::num::NonZeroU32; + +use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType}; +use crate::front::wgsl::index::Index; +use crate::front::wgsl::parse::number::Number; +use crate::front::wgsl::parse::{ast, conv}; +use crate::front::Typifier; +use crate::proc::{ + ensure_block_returns, Alignment, ConstantEvaluator, Emitter, Layouter, ResolveContext, +}; +use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span}; + +mod construction; +mod conversion; + +/// Resolves the inner type of a given expression. +/// +/// Expects a &mut [`ExpressionContext`] and a [`Handle`]. +/// +/// Returns a &[`crate::TypeInner`]. +/// +/// Ideally, we would simply have a function that takes a `&mut ExpressionContext` +/// and returns a `&TypeResolution`. Unfortunately, this leads the borrow checker +/// to conclude that the mutable borrow lasts for as long as we are using the +/// `&TypeResolution`, so we can't use the `ExpressionContext` for anything else - +/// like, say, resolving another operand's type. Using a macro that expands to +/// two separate calls, only the first of which needs a `&mut`, +/// lets the borrow checker see that the mutable borrow is over. +macro_rules! resolve_inner { + ($ctx:ident, $expr:expr) => {{ + $ctx.grow_types($expr)?; + $ctx.typifier()[$expr].inner_with(&$ctx.module.types) + }}; +} +pub(super) use resolve_inner; + +/// Resolves the inner types of two given expressions. +/// +/// Expects a &mut [`ExpressionContext`] and two [`Handle`]s. +/// +/// Returns a tuple containing two &[`crate::TypeInner`]. +/// +/// See the documentation of [`resolve_inner!`] for why this macro is necessary. +macro_rules! resolve_inner_binary { + ($ctx:ident, $left:expr, $right:expr) => {{ + $ctx.grow_types($left)?; + $ctx.grow_types($right)?; + ( + $ctx.typifier()[$left].inner_with(&$ctx.module.types), + $ctx.typifier()[$right].inner_with(&$ctx.module.types), + ) + }}; +} + +/// Resolves the type of a given expression. +/// +/// Expects a &mut [`ExpressionContext`] and a [`Handle`]. +/// +/// Returns a &[`TypeResolution`]. +/// +/// See the documentation of [`resolve_inner!`] for why this macro is necessary. +/// +/// [`TypeResolution`]: crate::proc::TypeResolution +macro_rules! resolve { + ($ctx:ident, $expr:expr) => {{ + $ctx.grow_types($expr)?; + &$ctx.typifier()[$expr] + }}; +} +pub(super) use resolve; + +/// State for constructing a `crate::Module`. +pub struct GlobalContext<'source, 'temp, 'out> { + /// The `TranslationUnit`'s expressions arena. + ast_expressions: &'temp Arena>, + + /// The `TranslationUnit`'s types arena. + types: &'temp Arena>, + + // Naga IR values. + /// The map from the names of module-scope declarations to the Naga IR + /// `Handle`s we have built for them, owned by `Lowerer::lower`. + globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, + + /// The module we're constructing. + module: &'out mut crate::Module, + + const_typifier: &'temp mut Typifier, +} + +impl<'source> GlobalContext<'source, '_, '_> { + fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + expr_type: ExpressionContextType::Constant, + } + } + + fn ensure_type_exists( + &mut self, + name: Option, + inner: crate::TypeInner, + ) -> Handle { + self.module + .types + .insert(crate::Type { inner, name }, Span::UNDEFINED) + } +} + +/// State for lowering a statement within a function. +pub struct StatementContext<'source, 'temp, 'out> { + // WGSL AST values. + /// A reference to [`TranslationUnit::expressions`] for the translation unit + /// we're lowering. + /// + /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions + ast_expressions: &'temp Arena>, + + /// A reference to [`TranslationUnit::types`] for the translation unit + /// we're lowering. + /// + /// [`TranslationUnit::types`]: ast::TranslationUnit::types + types: &'temp Arena>, + + // Naga IR values. + /// The map from the names of module-scope declarations to the Naga IR + /// `Handle`s we have built for them, owned by `Lowerer::lower`. + globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, + + /// A map from each `ast::Local` handle to the Naga expression + /// we've built for it: + /// + /// - WGSL function arguments become Naga [`FunctionArgument`] expressions. + /// + /// - WGSL `var` declarations become Naga [`LocalVariable`] expressions. + /// + /// - WGSL `let` declararations become arbitrary Naga expressions. + /// + /// This always borrows the `local_table` local variable in + /// [`Lowerer::function`]. + /// + /// [`LocalVariable`]: crate::Expression::LocalVariable + /// [`FunctionArgument`]: crate::Expression::FunctionArgument + local_table: &'temp mut FastHashMap, Typed>>, + + const_typifier: &'temp mut Typifier, + typifier: &'temp mut Typifier, + function: &'out mut crate::Function, + /// Stores the names of expressions that are assigned in `let` statement + /// Also stores the spans of the names, for use in errors. + named_expressions: &'out mut FastIndexMap, (String, Span)>, + module: &'out mut crate::Module, + + /// Which `Expression`s in `self.naga_expressions` are const expressions, in + /// the WGSL sense. + /// + /// According to the WGSL spec, a const expression must not refer to any + /// `let` declarations, even if those declarations' initializers are + /// themselves const expressions. So this tracker is not simply concerned + /// with the form of the expressions; it is also tracking whether WGSL says + /// we should consider them to be const. See the use of `force_non_const` in + /// the code for lowering `let` bindings. + expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, +} + +impl<'a, 'temp> StatementContext<'a, 'temp, '_> { + fn as_expression<'t>( + &'t mut self, + block: &'t mut crate::Block, + emitter: &'t mut Emitter, + ) -> ExpressionContext<'a, 't, '_> + where + 'temp: 't, + { + ExpressionContext { + globals: self.globals, + types: self.types, + ast_expressions: self.ast_expressions, + const_typifier: self.const_typifier, + module: self.module, + expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { + local_table: self.local_table, + function: self.function, + block, + emitter, + typifier: self.typifier, + expression_constness: self.expression_constness, + }), + } + } + + fn as_global(&mut self) -> GlobalContext<'a, '_, '_> { + GlobalContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + } + } + + fn invalid_assignment_type(&self, expr: Handle) -> InvalidAssignmentType { + if let Some(&(_, span)) = self.named_expressions.get(&expr) { + InvalidAssignmentType::ImmutableBinding(span) + } else { + match self.function.expressions[expr] { + crate::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle, + crate::Expression::Access { base, .. } => self.invalid_assignment_type(base), + crate::Expression::AccessIndex { base, .. } => self.invalid_assignment_type(base), + _ => InvalidAssignmentType::Other, + } + } + } +} + +pub struct RuntimeExpressionContext<'temp, 'out> { + /// A map from [`ast::Local`] handles to the Naga expressions we've built for them. + /// + /// This is always [`StatementContext::local_table`] for the + /// enclosing statement; see that documentation for details. + local_table: &'temp FastHashMap, Typed>>, + + function: &'out mut crate::Function, + block: &'temp mut crate::Block, + emitter: &'temp mut Emitter, + typifier: &'temp mut Typifier, + + /// Which `Expression`s in `self.naga_expressions` are const expressions, in + /// the WGSL sense. + /// + /// See [`StatementContext::expression_constness`] for details. + expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, +} + +/// The type of Naga IR expression we are lowering an [`ast::Expression`] to. +pub enum ExpressionContextType<'temp, 'out> { + /// We are lowering to an arbitrary runtime expression, to be + /// included in a function's body. + /// + /// The given [`RuntimeExpressionContext`] holds information about local + /// variables, arguments, and other definitions available only to runtime + /// expressions, not constant or override expressions. + Runtime(RuntimeExpressionContext<'temp, 'out>), + + /// We are lowering to a constant expression, to be included in the module's + /// constant expression arena. + /// + /// Everything constant expressions are allowed to refer to is + /// available in the [`ExpressionContext`], so this variant + /// carries no further information. + Constant, +} + +/// State for lowering an [`ast::Expression`] to Naga IR. +/// +/// [`ExpressionContext`]s come in two kinds, distinguished by +/// the value of the [`expr_type`] field: +/// +/// - A [`Runtime`] context contributes [`naga::Expression`]s to a [`naga::Function`]'s +/// runtime expression arena. +/// +/// - A [`Constant`] context contributes [`naga::Expression`]s to a [`naga::Module`]'s +/// constant expression arena. +/// +/// [`ExpressionContext`]s are constructed in restricted ways: +/// +/// - To get a [`Runtime`] [`ExpressionContext`], call +/// [`StatementContext::as_expression`]. +/// +/// - To get a [`Constant`] [`ExpressionContext`], call +/// [`GlobalContext::as_const`]. +/// +/// - You can demote a [`Runtime`] context to a [`Constant`] context +/// by calling [`as_const`], but there's no way to go in the other +/// direction, producing a runtime context from a constant one. This +/// is because runtime expressions can refer to constant +/// expressions, via [`Expression::Constant`], but constant +/// expressions can't refer to a function's expressions. +/// +/// Not to be confused with `wgsl::parse::ExpressionContext`, which is +/// for parsing the `ast::Expression` in the first place. +/// +/// [`expr_type`]: ExpressionContext::expr_type +/// [`Runtime`]: ExpressionContextType::Runtime +/// [`naga::Expression`]: crate::Expression +/// [`naga::Function`]: crate::Function +/// [`Constant`]: ExpressionContextType::Constant +/// [`naga::Module`]: crate::Module +/// [`as_const`]: ExpressionContext::as_const +/// [`Expression::Constant`]: crate::Expression::Constant +pub struct ExpressionContext<'source, 'temp, 'out> { + // WGSL AST values. + ast_expressions: &'temp Arena>, + types: &'temp Arena>, + + // Naga IR values. + /// The map from the names of module-scope declarations to the Naga IR + /// `Handle`s we have built for them, owned by `Lowerer::lower`. + globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, + + /// The IR [`Module`] we're constructing. + /// + /// [`Module`]: crate::Module + module: &'out mut crate::Module, + + /// Type judgments for [`module::const_expressions`]. + /// + /// [`module::const_expressions`]: crate::Module::const_expressions + const_typifier: &'temp mut Typifier, + + /// Whether we are lowering a constant expression or a general + /// runtime expression, and the data needed in each case. + expr_type: ExpressionContextType<'temp, 'out>, +} + +impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { + fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + globals: self.globals, + types: self.types, + ast_expressions: self.ast_expressions, + const_typifier: self.const_typifier, + module: self.module, + expr_type: ExpressionContextType::Constant, + } + } + + fn as_global(&mut self) -> GlobalContext<'source, '_, '_> { + GlobalContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + } + } + + fn as_const_evaluator(&mut self) -> ConstantEvaluator { + match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function( + self.module, + &mut rctx.function.expressions, + rctx.expression_constness, + rctx.emitter, + rctx.block, + ), + ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module), + } + } + + fn append_expression( + &mut self, + expr: crate::Expression, + span: Span, + ) -> Result, Error<'source>> { + let mut eval = self.as_const_evaluator(); + match eval.try_eval_and_append(&expr, span) { + Ok(expr) => Ok(expr), + + // `expr` is not a constant expression. This is fine as + // long as we're not building `Module::const_expressions`. + Err(err) => match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => { + Ok(rctx.function.expressions.append(expr, span)) + } + ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)), + }, + } + } + + fn const_access(&self, handle: Handle) -> Option { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => { + if !ctx.expression_constness.is_const(handle) { + return None; + } + + self.module + .to_ctx() + .eval_expr_to_u32_from(handle, &ctx.function.expressions) + .ok() + } + ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(), + } + } + + fn get_expression_span(&self, handle: Handle) -> Span { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), + ExpressionContextType::Constant => self.module.const_expressions.get_span(handle), + } + } + + fn typifier(&self) -> &Typifier { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => ctx.typifier, + ExpressionContextType::Constant => self.const_typifier, + } + } + + fn runtime_expression_ctx( + &mut self, + span: Span, + ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> { + match self.expr_type { + ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx), + ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), + } + } + + fn gather_component( + &mut self, + expr: Handle, + component_span: Span, + gather_span: Span, + ) -> Result> { + match self.expr_type { + ExpressionContextType::Runtime(ref rctx) => { + if !rctx.expression_constness.is_const(expr) { + return Err(Error::ExpectedConstExprConcreteIntegerScalar( + component_span, + )); + } + + let index = self + .module + .to_ctx() + .eval_expr_to_u32_from(expr, &rctx.function.expressions) + .map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(component_span) + } + crate::proc::U32EvalError::Negative => { + Error::ExpectedNonNegative(component_span) + } + })?; + crate::SwizzleComponent::XYZW + .get(index as usize) + .copied() + .ok_or(Error::InvalidGatherComponent(component_span)) + } + // This means a `gather` operation appeared in a constant expression. + // This error refers to the `gather` itself, not its "component" argument. + ExpressionContextType::Constant => { + Err(Error::UnexpectedOperationInConstContext(gather_span)) + } + } + } + + /// Determine the type of `handle`, and add it to the module's arena. + /// + /// If you just need a `TypeInner` for `handle`'s type, use the + /// [`resolve_inner!`] macro instead. This function + /// should only be used when the type of `handle` needs to appear + /// in the module's final `Arena`, for example, if you're + /// creating a [`LocalVariable`] whose type is inferred from its + /// initializer. + /// + /// [`LocalVariable`]: crate::LocalVariable + fn register_type( + &mut self, + handle: Handle, + ) -> Result, Error<'source>> { + self.grow_types(handle)?; + // This is equivalent to calling ExpressionContext::typifier(), + // except that this lets the borrow checker see that it's okay + // to also borrow self.module.types mutably below. + let typifier = match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => ctx.typifier, + ExpressionContextType::Constant => &*self.const_typifier, + }; + Ok(typifier.register_type(handle, &mut self.module.types)) + } + + /// Resolve the types of all expressions up through `handle`. + /// + /// Ensure that [`self.typifier`] has a [`TypeResolution`] for + /// every expression in [`self.function.expressions`]. + /// + /// This does not add types to any arena. The [`Typifier`] + /// documentation explains the steps we take to avoid filling + /// arenas with intermediate types. + /// + /// This function takes `&mut self`, so it can't conveniently + /// return a shared reference to the resulting `TypeResolution`: + /// the shared reference would extend the mutable borrow, and you + /// wouldn't be able to use `self` for anything else. Instead, you + /// should use [`register_type`] or one of [`resolve!`], + /// [`resolve_inner!`] or [`resolve_inner_binary!`]. + /// + /// [`self.typifier`]: ExpressionContext::typifier + /// [`TypeResolution`]: crate::proc::TypeResolution + /// [`register_type`]: Self::register_type + /// [`Typifier`]: Typifier + fn grow_types( + &mut self, + handle: Handle, + ) -> Result<&mut Self, Error<'source>> { + let empty_arena = Arena::new(); + let resolve_ctx; + let typifier; + let expressions; + match self.expr_type { + ExpressionContextType::Runtime(ref mut ctx) => { + resolve_ctx = ResolveContext::with_locals( + self.module, + &ctx.function.local_variables, + &ctx.function.arguments, + ); + typifier = &mut *ctx.typifier; + expressions = &ctx.function.expressions; + } + ExpressionContextType::Constant => { + resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); + typifier = self.const_typifier; + expressions = &self.module.const_expressions; + } + }; + typifier + .grow(handle, expressions, &resolve_ctx) + .map_err(Error::InvalidResolve)?; + + Ok(self) + } + + fn image_data( + &mut self, + image: Handle, + span: Span, + ) -> Result<(crate::ImageClass, bool), Error<'source>> { + match *resolve_inner!(self, image) { + crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)), + _ => Err(Error::BadTexture(span)), + } + } + + fn prepare_args<'b>( + &mut self, + args: &'b [Handle>], + min_args: u32, + span: Span, + ) -> ArgumentContext<'b, 'source> { + ArgumentContext { + args: args.iter(), + min_args, + args_used: 0, + total_args: args.len() as u32, + span, + } + } + + /// Insert splats, if needed by the non-'*' operations. + /// + /// See the "Binary arithmetic expressions with mixed scalar and vector operands" + /// table in the WebGPU Shading Language specification for relevant operators. + /// + /// Multiply is not handled here as backends are expected to handle vec*scalar + /// operations, so inserting splats into the IR increases size needlessly. + fn binary_op_splat( + &mut self, + op: crate::BinaryOperator, + left: &mut Handle, + right: &mut Handle, + ) -> Result<(), Error<'source>> { + if matches!( + op, + crate::BinaryOperator::Add + | crate::BinaryOperator::Subtract + | crate::BinaryOperator::Divide + | crate::BinaryOperator::Modulo + ) { + match resolve_inner_binary!(self, *left, *right) { + (&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => { + *right = self.append_expression( + crate::Expression::Splat { + size, + value: *right, + }, + self.get_expression_span(*right), + )?; + } + (&crate::TypeInner::Scalar { .. }, &crate::TypeInner::Vector { size, .. }) => { + *left = self.append_expression( + crate::Expression::Splat { size, value: *left }, + self.get_expression_span(*left), + )?; + } + _ => {} + } + } + + Ok(()) + } + + /// Add a single expression to the expression table that is not covered by `self.emitter`. + /// + /// This is useful for `CallResult` and `AtomicResult` expressions, which should not be covered by + /// `Emit` statements. + fn interrupt_emitter( + &mut self, + expression: crate::Expression, + span: Span, + ) -> Result, Error<'source>> { + match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => { + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + } + ExpressionContextType::Constant => {} + } + let result = self.append_expression(expression, span); + match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => { + rctx.emitter.start(&rctx.function.expressions); + } + ExpressionContextType::Constant => {} + } + result + } + + /// Apply the WGSL Load Rule to `expr`. + /// + /// If `expr` is has type `ref`, perform a load to produce a value of type + /// `T`. Otherwise, return `expr` unchanged. + fn apply_load_rule( + &mut self, + expr: Typed>, + ) -> Result, Error<'source>> { + match expr { + Typed::Reference(pointer) => { + let load = crate::Expression::Load { pointer }; + let span = self.get_expression_span(pointer); + self.append_expression(load, span) + } + Typed::Plain(handle) => Ok(handle), + } + } + + fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle { + self.as_global().ensure_type_exists(None, inner) + } +} + +struct ArgumentContext<'ctx, 'source> { + args: std::slice::Iter<'ctx, Handle>>, + min_args: u32, + args_used: u32, + total_args: u32, + span: Span, +} + +impl<'source> ArgumentContext<'_, 'source> { + pub fn finish(self) -> Result<(), Error<'source>> { + if self.args.len() == 0 { + Ok(()) + } else { + Err(Error::WrongArgumentCount { + found: self.total_args, + expected: self.min_args..self.args_used + 1, + span: self.span, + }) + } + } + + pub fn next(&mut self) -> Result>, Error<'source>> { + match self.args.next().copied() { + Some(arg) => { + self.args_used += 1; + Ok(arg) + } + None => Err(Error::WrongArgumentCount { + found: self.total_args, + expected: self.min_args..self.args_used + 1, + span: self.span, + }), + } + } +} + +/// WGSL type annotations on expressions, types, values, etc. +/// +/// Naga and WGSL types are very close, but Naga lacks WGSL's `ref` types, which +/// we need to know to apply the Load Rule. This enum carries some WGSL or Naga +/// datum along with enough information to determine its corresponding WGSL +/// type. +/// +/// The `T` type parameter can be any expression-like thing: +/// +/// - `Typed>` can represent a full WGSL type. For example, +/// given some Naga `Pointer` type `ptr`, a WGSL reference type is a +/// `Typed::Reference(ptr)` whereas a WGSL pointer type is a +/// `Typed::Plain(ptr)`. +/// +/// - `Typed` or `Typed>` can +/// represent references similarly. +/// +/// Use the `map` and `try_map` methods to convert from one expression +/// representation to another. +/// +/// [`Expression`]: crate::Expression +#[derive(Debug, Copy, Clone)] +enum Typed { + /// A WGSL reference. + Reference(T), + + /// A WGSL plain type. + Plain(T), +} + +impl Typed { + fn map(self, mut f: impl FnMut(T) -> U) -> Typed { + match self { + Self::Reference(v) => Typed::Reference(f(v)), + Self::Plain(v) => Typed::Plain(f(v)), + } + } + + fn try_map(self, mut f: impl FnMut(T) -> Result) -> Result, E> { + Ok(match self { + Self::Reference(expr) => Typed::Reference(f(expr)?), + Self::Plain(expr) => Typed::Plain(f(expr)?), + }) + } +} + +/// A single vector component or swizzle. +/// +/// This represents the things that can appear after the `.` in a vector access +/// expression: either a single component name, or a series of them, +/// representing a swizzle. +enum Components { + Single(u32), + Swizzle { + size: crate::VectorSize, + pattern: [crate::SwizzleComponent; 4], + }, +} + +impl Components { + const fn letter_component(letter: char) -> Option { + use crate::SwizzleComponent as Sc; + match letter { + 'x' | 'r' => Some(Sc::X), + 'y' | 'g' => Some(Sc::Y), + 'z' | 'b' => Some(Sc::Z), + 'w' | 'a' => Some(Sc::W), + _ => None, + } + } + + fn single_component(name: &str, name_span: Span) -> Result { + let ch = name.chars().next().ok_or(Error::BadAccessor(name_span))?; + match Self::letter_component(ch) { + Some(sc) => Ok(sc as u32), + None => Err(Error::BadAccessor(name_span)), + } + } + + /// Construct a `Components` value from a 'member' name, like `"wzy"` or `"x"`. + /// + /// Use `name_span` for reporting errors in parsing the component string. + fn new(name: &str, name_span: Span) -> Result { + let size = match name.len() { + 1 => return Ok(Components::Single(Self::single_component(name, name_span)?)), + 2 => crate::VectorSize::Bi, + 3 => crate::VectorSize::Tri, + 4 => crate::VectorSize::Quad, + _ => return Err(Error::BadAccessor(name_span)), + }; + + let mut pattern = [crate::SwizzleComponent::X; 4]; + for (comp, ch) in pattern.iter_mut().zip(name.chars()) { + *comp = Self::letter_component(ch).ok_or(Error::BadAccessor(name_span))?; + } + + Ok(Components::Swizzle { size, pattern }) + } +} + +/// An `ast::GlobalDecl` for which we have built the Naga IR equivalent. +enum LoweredGlobalDecl { + Function(Handle), + Var(Handle), + Const(Handle), + Type(Handle), + EntryPoint, +} + +enum Texture { + Gather, + GatherCompare, + + Sample, + SampleBias, + SampleCompare, + SampleCompareLevel, + SampleGrad, + SampleLevel, + // SampleBaseClampToEdge, +} + +impl Texture { + pub fn map(word: &str) -> Option { + Some(match word { + "textureGather" => Self::Gather, + "textureGatherCompare" => Self::GatherCompare, + + "textureSample" => Self::Sample, + "textureSampleBias" => Self::SampleBias, + "textureSampleCompare" => Self::SampleCompare, + "textureSampleCompareLevel" => Self::SampleCompareLevel, + "textureSampleGrad" => Self::SampleGrad, + "textureSampleLevel" => Self::SampleLevel, + // "textureSampleBaseClampToEdge" => Some(Self::SampleBaseClampToEdge), + _ => return None, + }) + } + + pub const fn min_argument_count(&self) -> u32 { + match *self { + Self::Gather => 3, + Self::GatherCompare => 4, + + Self::Sample => 3, + Self::SampleBias => 5, + Self::SampleCompare => 5, + Self::SampleCompareLevel => 5, + Self::SampleGrad => 6, + Self::SampleLevel => 5, + // Self::SampleBaseClampToEdge => 3, + } + } +} + +pub struct Lowerer<'source, 'temp> { + index: &'temp Index<'source>, + layouter: Layouter, +} + +impl<'source, 'temp> Lowerer<'source, 'temp> { + pub fn new(index: &'temp Index<'source>) -> Self { + Self { + index, + layouter: Layouter::default(), + } + } + + pub fn lower( + &mut self, + tu: &'temp ast::TranslationUnit<'source>, + ) -> Result> { + let mut module = crate::Module::default(); + + let mut ctx = GlobalContext { + ast_expressions: &tu.expressions, + globals: &mut FastHashMap::default(), + types: &tu.types, + module: &mut module, + const_typifier: &mut Typifier::new(), + }; + + for decl_handle in self.index.visit_ordered() { + let span = tu.decls.get_span(decl_handle); + let decl = &tu.decls[decl_handle]; + + match decl.kind { + ast::GlobalDeclKind::Fn(ref f) => { + let lowered_decl = self.function(f, span, &mut ctx)?; + ctx.globals.insert(f.name.name, lowered_decl); + } + ast::GlobalDeclKind::Var(ref v) => { + let ty = self.resolve_ast_type(v.ty, &mut ctx)?; + + let init; + if let Some(init_ast) = v.init { + let mut ectx = ctx.as_const(); + let lowered = self.expression_for_abstract(init_ast, &mut ectx)?; + let ty_res = crate::proc::TypeResolution::Handle(ty); + let converted = ectx + .try_automatic_conversions(lowered, &ty_res, v.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: v.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + init = Some(converted); + } else { + init = None; + } + + let binding = if let Some(ref binding) = v.binding { + Some(crate::ResourceBinding { + group: self.const_u32(binding.group, &mut ctx.as_const())?.0, + binding: self.const_u32(binding.binding, &mut ctx.as_const())?.0, + }) + } else { + None + }; + + let handle = ctx.module.global_variables.append( + crate::GlobalVariable { + name: Some(v.name.name.to_string()), + space: v.space, + binding, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(v.name.name, LoweredGlobalDecl::Var(handle)); + } + ast::GlobalDeclKind::Const(ref c) => { + let mut ectx = ctx.as_const(); + let mut init = self.expression_for_abstract(c.init, &mut ectx)?; + + let ty; + if let Some(explicit_ty) = c.ty { + let explicit_ty = + self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?; + let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + init = ectx + .try_automatic_conversions(init, &explicit_ty_res, c.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: c.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + ty = explicit_ty; + } else { + init = ectx.concretize(init)?; + ty = ectx.register_type(init)?; + } + + let handle = ctx.module.constants.append( + crate::Constant { + name: Some(c.name.name.to_string()), + r#override: crate::Override::None, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(c.name.name, LoweredGlobalDecl::Const(handle)); + } + ast::GlobalDeclKind::Struct(ref s) => { + let handle = self.r#struct(s, span, &mut ctx)?; + ctx.globals + .insert(s.name.name, LoweredGlobalDecl::Type(handle)); + } + ast::GlobalDeclKind::Type(ref alias) => { + let ty = self.resolve_named_ast_type( + alias.ty, + Some(alias.name.name.to_string()), + &mut ctx, + )?; + ctx.globals + .insert(alias.name.name, LoweredGlobalDecl::Type(ty)); + } + } + } + + // Constant evaluation may leave abstract-typed literals and + // compositions in expression arenas, so we need to compact the module + // to remove unused expressions and types. + crate::compact::compact(&mut module); + + Ok(module) + } + + fn function( + &mut self, + f: &ast::Function<'source>, + span: Span, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result> { + let mut local_table = FastHashMap::default(); + let mut expressions = Arena::new(); + let mut named_expressions = FastIndexMap::default(); + + let arguments = f + .arguments + .iter() + .enumerate() + .map(|(i, arg)| { + let ty = self.resolve_ast_type(arg.ty, ctx)?; + let expr = expressions + .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); + local_table.insert(arg.handle, Typed::Plain(expr)); + named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span)); + + Ok(crate::FunctionArgument { + name: Some(arg.name.name.to_string()), + ty, + binding: self.binding(&arg.binding, ty, ctx)?, + }) + }) + .collect::, _>>()?; + + let result = f + .result + .as_ref() + .map(|res| { + let ty = self.resolve_ast_type(res.ty, ctx)?; + Ok(crate::FunctionResult { + ty, + binding: self.binding(&res.binding, ty, ctx)?, + }) + }) + .transpose()?; + + let mut function = crate::Function { + name: Some(f.name.name.to_string()), + arguments, + result, + local_variables: Arena::new(), + expressions, + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::default(), + }; + + let mut typifier = Typifier::default(); + let mut stmt_ctx = StatementContext { + local_table: &mut local_table, + globals: ctx.globals, + ast_expressions: ctx.ast_expressions, + const_typifier: ctx.const_typifier, + typifier: &mut typifier, + function: &mut function, + named_expressions: &mut named_expressions, + types: ctx.types, + module: ctx.module, + expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), + }; + let mut body = self.block(&f.body, false, &mut stmt_ctx)?; + ensure_block_returns(&mut body); + + function.body = body; + function.named_expressions = named_expressions + .into_iter() + .map(|(key, (name, _))| (key, name)) + .collect(); + + if let Some(ref entry) = f.entry_point { + let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size { + // TODO: replace with try_map once stabilized + let mut workgroup_size_out = [1; 3]; + for (i, size) in workgroup_size.into_iter().enumerate() { + if let Some(size_expr) = size { + workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0; + } + } + workgroup_size_out + } else { + [0; 3] + }; + + ctx.module.entry_points.push(crate::EntryPoint { + name: f.name.name.to_string(), + stage: entry.stage, + early_depth_test: entry.early_depth_test, + workgroup_size, + function, + }); + Ok(LoweredGlobalDecl::EntryPoint) + } else { + let handle = ctx.module.functions.append(function, span); + Ok(LoweredGlobalDecl::Function(handle)) + } + } + + fn block( + &mut self, + b: &ast::Block<'source>, + is_inside_loop: bool, + ctx: &mut StatementContext<'source, '_, '_>, + ) -> Result> { + let mut block = crate::Block::default(); + + for stmt in b.stmts.iter() { + self.statement(stmt, &mut block, is_inside_loop, ctx)?; + } + + Ok(block) + } + + fn statement( + &mut self, + stmt: &ast::Statement<'source>, + block: &mut crate::Block, + is_inside_loop: bool, + ctx: &mut StatementContext<'source, '_, '_>, + ) -> Result<(), Error<'source>> { + let out = match stmt.kind { + ast::StatementKind::Block(ref block) => { + let block = self.block(block, is_inside_loop, ctx)?; + crate::Statement::Block(block) + } + ast::StatementKind::LocalDecl(ref decl) => match *decl { + ast::LocalDecl::Let(ref l) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let value = + self.expression(l.init, &mut ctx.as_expression(block, &mut emitter))?; + + // The WGSL spec says that any expression that refers to a + // `let`-bound variable is not a const expression. This + // affects when errors must be reported, so we can't even + // treat suitable `let` bindings as constant as an + // optimization. + ctx.expression_constness.force_non_const(value); + + let explicit_ty = + l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global())) + .transpose()?; + + if let Some(ty) = explicit_ty { + let mut ctx = ctx.as_expression(block, &mut emitter); + let init_ty = ctx.register_type(value)?; + if !ctx.module.types[ty] + .inner + .equivalent(&ctx.module.types[init_ty].inner, &ctx.module.types) + { + let gctx = &ctx.module.to_ctx(); + return Err(Error::InitializationTypeMismatch { + name: l.name.span, + expected: ty.to_wgsl(gctx), + got: init_ty.to_wgsl(gctx), + }); + } + } + + block.extend(emitter.finish(&ctx.function.expressions)); + ctx.local_table.insert(l.handle, Typed::Plain(value)); + ctx.named_expressions + .insert(value, (l.name.name.to_string(), l.name.span)); + + return Ok(()); + } + ast::LocalDecl::Var(ref v) => { + let explicit_ty = + v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global())) + .transpose()?; + + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + let mut ectx = ctx.as_expression(block, &mut emitter); + + let ty; + let initializer; + match (v.init, explicit_ty) { + (Some(init), Some(explicit_ty)) => { + let init = self.expression_for_abstract(init, &mut ectx)?; + let ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + let init = ectx + .try_automatic_conversions(init, &ty_res, v.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: v.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + ty = explicit_ty; + initializer = Some(init); + } + (Some(init), None) => { + let concretized = self.expression(init, &mut ectx)?; + ty = ectx.register_type(concretized)?; + initializer = Some(concretized); + } + (None, Some(explicit_ty)) => { + ty = explicit_ty; + initializer = None; + } + (None, None) => return Err(Error::MissingType(v.name.span)), + } + + let (const_initializer, initializer) = { + match initializer { + Some(init) => { + // It's not correct to hoist the initializer up + // to the top of the function if: + // - the initialization is inside a loop, and should + // take place on every iteration, or + // - the initialization is not a constant + // expression, so its value depends on the + // state at the point of initialization. + if is_inside_loop || !ctx.expression_constness.is_const(init) { + (None, Some(init)) + } else { + (Some(init), None) + } + } + None => (None, None), + } + }; + + let var = ctx.function.local_variables.append( + crate::LocalVariable { + name: Some(v.name.name.to_string()), + ty, + init: const_initializer, + }, + stmt.span, + ); + + let handle = ctx.as_expression(block, &mut emitter).interrupt_emitter( + crate::Expression::LocalVariable(var), + Span::UNDEFINED, + )?; + block.extend(emitter.finish(&ctx.function.expressions)); + ctx.local_table.insert(v.handle, Typed::Reference(handle)); + + match initializer { + Some(initializer) => crate::Statement::Store { + pointer: handle, + value: initializer, + }, + None => return Ok(()), + } + } + }, + ast::StatementKind::If { + condition, + ref accept, + ref reject, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let condition = + self.expression(condition, &mut ctx.as_expression(block, &mut emitter))?; + block.extend(emitter.finish(&ctx.function.expressions)); + + let accept = self.block(accept, is_inside_loop, ctx)?; + let reject = self.block(reject, is_inside_loop, ctx)?; + + crate::Statement::If { + condition, + accept, + reject, + } + } + ast::StatementKind::Switch { + selector, + ref cases, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let mut ectx = ctx.as_expression(block, &mut emitter); + let selector = self.expression(selector, &mut ectx)?; + + let uint = + resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint); + block.extend(emitter.finish(&ctx.function.expressions)); + + let cases = cases + .iter() + .map(|case| { + Ok(crate::SwitchCase { + value: match case.value { + ast::SwitchValue::Expr(expr) => { + let span = ctx.ast_expressions.get_span(expr); + let expr = + self.expression(expr, &mut ctx.as_global().as_const())?; + match ctx.module.to_ctx().eval_expr_to_literal(expr) { + Some(crate::Literal::I32(value)) if !uint => { + crate::SwitchValue::I32(value) + } + Some(crate::Literal::U32(value)) if uint => { + crate::SwitchValue::U32(value) + } + _ => { + return Err(Error::InvalidSwitchValue { uint, span }); + } + } + } + ast::SwitchValue::Default => crate::SwitchValue::Default, + }, + body: self.block(&case.body, is_inside_loop, ctx)?, + fall_through: case.fall_through, + }) + }) + .collect::>()?; + + crate::Statement::Switch { selector, cases } + } + ast::StatementKind::Loop { + ref body, + ref continuing, + break_if, + } => { + let body = self.block(body, true, ctx)?; + let mut continuing = self.block(continuing, true, ctx)?; + + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + let break_if = break_if + .map(|expr| { + self.expression(expr, &mut ctx.as_expression(&mut continuing, &mut emitter)) + }) + .transpose()?; + continuing.extend(emitter.finish(&ctx.function.expressions)); + + crate::Statement::Loop { + body, + continuing, + break_if, + } + } + ast::StatementKind::Break => crate::Statement::Break, + ast::StatementKind::Continue => crate::Statement::Continue, + ast::StatementKind::Return { value } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let value = value + .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter))) + .transpose()?; + block.extend(emitter.finish(&ctx.function.expressions)); + + crate::Statement::Return { value } + } + ast::StatementKind::Kill => crate::Statement::Kill, + ast::StatementKind::Call { + ref function, + ref arguments, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let _ = self.call( + stmt.span, + function, + arguments, + &mut ctx.as_expression(block, &mut emitter), + )?; + block.extend(emitter.finish(&ctx.function.expressions)); + return Ok(()); + } + ast::StatementKind::Assign { + target: ast_target, + op, + value, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let target = self.expression_for_reference( + ast_target, + &mut ctx.as_expression(block, &mut emitter), + )?; + let mut value = + self.expression(value, &mut ctx.as_expression(block, &mut emitter))?; + + let target_handle = match target { + Typed::Reference(handle) => handle, + Typed::Plain(handle) => { + let ty = ctx.invalid_assignment_type(handle); + return Err(Error::InvalidAssignment { + span: ctx.ast_expressions.get_span(ast_target), + ty, + }); + } + }; + + let value = match op { + Some(op) => { + let mut ctx = ctx.as_expression(block, &mut emitter); + let mut left = ctx.apply_load_rule(target)?; + ctx.binary_op_splat(op, &mut left, &mut value)?; + ctx.append_expression( + crate::Expression::Binary { + op, + left, + right: value, + }, + stmt.span, + )? + } + None => value, + }; + block.extend(emitter.finish(&ctx.function.expressions)); + + crate::Statement::Store { + pointer: target_handle, + value, + } + } + ast::StatementKind::Increment(value) | ast::StatementKind::Decrement(value) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let op = match stmt.kind { + ast::StatementKind::Increment(_) => crate::BinaryOperator::Add, + ast::StatementKind::Decrement(_) => crate::BinaryOperator::Subtract, + _ => unreachable!(), + }; + + let value_span = ctx.ast_expressions.get_span(value); + let target = self + .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?; + let target_handle = match target { + Typed::Reference(handle) => handle, + Typed::Plain(_) => return Err(Error::BadIncrDecrReferenceType(value_span)), + }; + + let mut ectx = ctx.as_expression(block, &mut emitter); + let scalar = match *resolve_inner!(ectx, target_handle) { + crate::TypeInner::ValuePointer { + size: None, scalar, .. + } => scalar, + crate::TypeInner::Pointer { base, .. } => match ectx.module.types[base].inner { + crate::TypeInner::Scalar(scalar) => scalar, + _ => return Err(Error::BadIncrDecrReferenceType(value_span)), + }, + _ => return Err(Error::BadIncrDecrReferenceType(value_span)), + }; + let literal = match scalar.kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + crate::Literal::one(scalar) + .ok_or(Error::BadIncrDecrReferenceType(value_span))? + } + _ => return Err(Error::BadIncrDecrReferenceType(value_span)), + }; + + let right = + ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED)?; + let rctx = ectx.runtime_expression_ctx(stmt.span)?; + let left = rctx.function.expressions.append( + crate::Expression::Load { + pointer: target_handle, + }, + value_span, + ); + let value = rctx + .function + .expressions + .append(crate::Expression::Binary { op, left, right }, stmt.span); + + block.extend(emitter.finish(&ctx.function.expressions)); + crate::Statement::Store { + pointer: target_handle, + value, + } + } + ast::StatementKind::Ignore(expr) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let _ = self.expression(expr, &mut ctx.as_expression(block, &mut emitter))?; + block.extend(emitter.finish(&ctx.function.expressions)); + return Ok(()); + } + }; + + block.push(out, stmt.span); + + Ok(()) + } + + /// Lower `expr` and apply the Load Rule if possible. + /// + /// For the time being, this concretizes abstract values, to support + /// consumers that haven't been adapted to consume them yet. Consumers + /// prepared for abstract values can call [`expression_for_abstract`]. + /// + /// [`expression_for_abstract`]: Lowerer::expression_for_abstract + fn expression( + &mut self, + expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let expr = self.expression_for_abstract(expr, ctx)?; + ctx.concretize(expr) + } + + fn expression_for_abstract( + &mut self, + expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let expr = self.expression_for_reference(expr, ctx)?; + ctx.apply_load_rule(expr) + } + + fn expression_for_reference( + &mut self, + expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result>, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let expr = &ctx.ast_expressions[expr]; + + let expr: Typed = match *expr { + ast::Expression::Literal(literal) => { + let literal = match literal { + ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f), + ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), + ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), + ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f), + ast::Literal::Number(Number::AbstractInt(i)) => crate::Literal::AbstractInt(i), + ast::Literal::Number(Number::AbstractFloat(f)) => { + crate::Literal::AbstractFloat(f) + } + ast::Literal::Bool(b) => crate::Literal::Bool(b), + }; + let handle = ctx.interrupt_emitter(crate::Expression::Literal(literal), span)?; + return Ok(Typed::Plain(handle)); + } + ast::Expression::Ident(ast::IdentExpr::Local(local)) => { + let rctx = ctx.runtime_expression_ctx(span)?; + return Ok(rctx.local_table[&local]); + } + ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => { + let global = ctx + .globals + .get(name) + .ok_or(Error::UnknownIdent(span, name))?; + let expr = match *global { + LoweredGlobalDecl::Var(handle) => { + let expr = crate::Expression::GlobalVariable(handle); + match ctx.module.global_variables[handle].space { + crate::AddressSpace::Handle => Typed::Plain(expr), + _ => Typed::Reference(expr), + } + } + LoweredGlobalDecl::Const(handle) => { + Typed::Plain(crate::Expression::Constant(handle)) + } + _ => { + return Err(Error::Unexpected(span, ExpectedToken::Variable)); + } + }; + + return expr.try_map(|handle| ctx.interrupt_emitter(handle, span)); + } + ast::Expression::Construct { + ref ty, + ty_span, + ref components, + } => { + let handle = self.construct(span, ty, ty_span, components, ctx)?; + return Ok(Typed::Plain(handle)); + } + ast::Expression::Unary { op, expr } => { + let expr = self.expression_for_abstract(expr, ctx)?; + Typed::Plain(crate::Expression::Unary { op, expr }) + } + ast::Expression::AddrOf(expr) => { + // The `&` operator simply converts a reference to a pointer. And since a + // reference is required, the Load Rule is not applied. + match self.expression_for_reference(expr, ctx)? { + Typed::Reference(handle) => { + // No code is generated. We just declare the reference a pointer now. + return Ok(Typed::Plain(handle)); + } + Typed::Plain(_) => { + return Err(Error::NotReference("the operand of the `&` operator", span)); + } + } + } + ast::Expression::Deref(expr) => { + // The pointer we dereference must be loaded. + let pointer = self.expression(expr, ctx)?; + + if resolve_inner!(ctx, pointer).pointer_space().is_none() { + return Err(Error::NotPointer(span)); + } + + // No code is generated. We just declare the pointer a reference now. + return Ok(Typed::Reference(pointer)); + } + ast::Expression::Binary { op, left, right } => { + self.binary(op, left, right, span, ctx)? + } + ast::Expression::Call { + ref function, + ref arguments, + } => { + let handle = self + .call(span, function, arguments, ctx)? + .ok_or(Error::FunctionReturnsVoid(function.span))?; + return Ok(Typed::Plain(handle)); + } + ast::Expression::Index { base, index } => { + let lowered_base = self.expression_for_reference(base, ctx)?; + let index = self.expression(index, ctx)?; + + if let Typed::Plain(handle) = lowered_base { + if resolve_inner!(ctx, handle).pointer_space().is_some() { + return Err(Error::Pointer( + "the value indexed by a `[]` subscripting expression", + ctx.ast_expressions.get_span(base), + )); + } + } + + lowered_base.map(|base| match ctx.const_access(index) { + Some(index) => crate::Expression::AccessIndex { base, index }, + None => crate::Expression::Access { base, index }, + }) + } + ast::Expression::Member { base, ref field } => { + let lowered_base = self.expression_for_reference(base, ctx)?; + + let temp_inner; + let composite_type: &crate::TypeInner = match lowered_base { + Typed::Reference(handle) => { + let inner = resolve_inner!(ctx, handle); + match *inner { + crate::TypeInner::Pointer { base, .. } => &ctx.module.types[base].inner, + crate::TypeInner::ValuePointer { + size: None, scalar, .. + } => { + temp_inner = crate::TypeInner::Scalar(scalar); + &temp_inner + } + crate::TypeInner::ValuePointer { + size: Some(size), + scalar, + .. + } => { + temp_inner = crate::TypeInner::Vector { size, scalar }; + &temp_inner + } + _ => unreachable!( + "In Typed::Reference(handle), handle must be a Naga pointer" + ), + } + } + + Typed::Plain(handle) => { + let inner = resolve_inner!(ctx, handle); + if let crate::TypeInner::Pointer { .. } + | crate::TypeInner::ValuePointer { .. } = *inner + { + return Err(Error::Pointer( + "the value accessed by a `.member` expression", + ctx.ast_expressions.get_span(base), + )); + } + inner + } + }; + + let access = match *composite_type { + crate::TypeInner::Struct { ref members, .. } => { + let index = members + .iter() + .position(|m| m.name.as_deref() == Some(field.name)) + .ok_or(Error::BadAccessor(field.span))? + as u32; + + lowered_base.map(|base| crate::Expression::AccessIndex { base, index }) + } + crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => { + match Components::new(field.name, field.span)? { + Components::Swizzle { size, pattern } => { + // Swizzles aren't allowed on matrices, but + // validation will catch that. + Typed::Plain(crate::Expression::Swizzle { + size, + vector: ctx.apply_load_rule(lowered_base)?, + pattern, + }) + } + Components::Single(index) => lowered_base + .map(|base| crate::Expression::AccessIndex { base, index }), + } + } + _ => return Err(Error::BadAccessor(field.span)), + }; + + access + } + ast::Expression::Bitcast { expr, to, ty_span } => { + let expr = self.expression(expr, ctx)?; + let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?; + + let element_scalar = match ctx.module.types[to_resolved].inner { + crate::TypeInner::Scalar(scalar) => scalar, + crate::TypeInner::Vector { scalar, .. } => scalar, + _ => { + let ty = resolve!(ctx, expr); + let gctx = &ctx.module.to_ctx(); + return Err(Error::BadTypeCast { + from_type: ty.to_wgsl(gctx), + span: ty_span, + to_type: to_resolved.to_wgsl(gctx), + }); + } + }; + + Typed::Plain(crate::Expression::As { + expr, + kind: element_scalar.kind, + convert: None, + }) + } + }; + + expr.try_map(|handle| ctx.append_expression(handle, span)) + } + + fn binary( + &mut self, + op: crate::BinaryOperator, + left: Handle>, + right: Handle>, + span: Span, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + // Load both operands. + let mut left = self.expression_for_abstract(left, ctx)?; + let mut right = self.expression_for_abstract(right, ctx)?; + + // Convert `scalar op vector` to `vector op vector` by introducing + // `Splat` expressions. + ctx.binary_op_splat(op, &mut left, &mut right)?; + + // Apply automatic conversions. + match op { + // Shift operators require the right operand to be `u32` or + // `vecN`. We can let the validator sort out vector length + // issues, but the right operand must be, or convert to, a u32 leaf + // scalar. + crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight => { + right = + ctx.try_automatic_conversion_for_leaf_scalar(right, crate::Scalar::U32, span)?; + } + + // All other operators follow the same pattern: reconcile the + // scalar leaf types. If there's no reconciliation possible, + // leave the expressions as they are: validation will report the + // problem. + _ => { + ctx.grow_types(left)?; + ctx.grow_types(right)?; + if let Ok(consensus_scalar) = + ctx.automatic_conversion_consensus([left, right].iter()) + { + ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?; + ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?; + } + } + } + + Ok(Typed::Plain(crate::Expression::Binary { op, left, right })) + } + + /// Generate Naga IR for call expressions and statements, and type + /// constructor expressions. + /// + /// The "function" being called is simply an `Ident` that we know refers to + /// some module-scope definition. + /// + /// - If it is the name of a type, then the expression is a type constructor + /// expression: either constructing a value from components, a conversion + /// expression, or a zero value expression. + /// + /// - If it is the name of a function, then we're generating a [`Call`] + /// statement. We may be in the midst of generating code for an + /// expression, in which case we must generate an `Emit` statement to + /// force evaluation of the IR expressions we've generated so far, add the + /// `Call` statement to the current block, and then resume generating + /// expressions. + /// + /// [`Call`]: crate::Statement::Call + fn call( + &mut self, + span: Span, + function: &ast::Ident<'source>, + arguments: &[Handle>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result>, Error<'source>> { + match ctx.globals.get(function.name) { + Some(&LoweredGlobalDecl::Type(ty)) => { + let handle = self.construct( + span, + &ast::ConstructorType::Type(ty), + function.span, + arguments, + ctx, + )?; + Ok(Some(handle)) + } + Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => { + Err(Error::Unexpected(function.span, ExpectedToken::Function)) + } + Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)), + Some(&LoweredGlobalDecl::Function(function)) => { + let arguments = arguments + .iter() + .map(|&arg| self.expression(arg, ctx)) + .collect::, _>>()?; + + let has_result = ctx.module.functions[function].result.is_some(); + let rctx = ctx.runtime_expression_ctx(span)?; + // we need to always do this before a fn call since all arguments need to be emitted before the fn call + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + let result = has_result.then(|| { + rctx.function + .expressions + .append(crate::Expression::CallResult(function), span) + }); + rctx.emitter.start(&rctx.function.expressions); + rctx.block.push( + crate::Statement::Call { + function, + arguments, + result, + }, + span, + ); + + Ok(result) + } + None => { + let span = function.span; + let expr = if let Some(fun) = conv::map_relational_fun(function.name) { + let mut args = ctx.prepare_args(arguments, 1, span); + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + // Check for no-op all(bool) and any(bool): + let argument_unmodified = matches!( + fun, + crate::RelationalFunction::All | crate::RelationalFunction::Any + ) && { + matches!( + resolve_inner!(ctx, argument), + &crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + .. + }) + ) + }; + + if argument_unmodified { + return Ok(Some(argument)); + } else { + crate::Expression::Relational { fun, argument } + } + } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::Derivative { axis, ctrl, expr } + } else if let Some(fun) = conv::map_standard_fun(function.name) { + let expected = fun.argument_count() as _; + let mut args = ctx.prepare_args(arguments, expected, span); + + let arg = self.expression(args.next()?, ctx)?; + let arg1 = args + .next() + .map(|x| self.expression(x, ctx)) + .ok() + .transpose()?; + let arg2 = args + .next() + .map(|x| self.expression(x, ctx)) + .ok() + .transpose()?; + let arg3 = args + .next() + .map(|x| self.expression(x, ctx)) + .ok() + .transpose()?; + + args.finish()?; + + if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp { + if let Some((size, width)) = match *resolve_inner!(ctx, arg) { + crate::TypeInner::Scalar(crate::Scalar { width, .. }) => { + Some((None, width)) + } + crate::TypeInner::Vector { + size, + scalar: crate::Scalar { width, .. }, + .. + } => Some((Some(size), width)), + _ => None, + } { + ctx.module.generate_predeclared_type( + if fun == crate::MathFunction::Modf { + crate::PredeclaredType::ModfResult { size, width } + } else { + crate::PredeclaredType::FrexpResult { size, width } + }, + ); + } + } + + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } + } else if let Some(fun) = Texture::map(function.name) { + self.texture_sample_helper(fun, arguments, span, ctx)? + } else { + match function.name { + "select" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let reject = self.expression(args.next()?, ctx)?; + let accept = self.expression(args.next()?, ctx)?; + let condition = self.expression(args.next()?, ctx)?; + + args.finish()?; + + crate::Expression::Select { + reject, + accept, + condition, + } + } + "arrayLength" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ArrayLength(expr) + } + "atomicLoad" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let pointer = self.atomic_pointer(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::Load { pointer } + } + "atomicStore" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let pointer = self.atomic_pointer(args.next()?, ctx)?; + let value = self.expression(args.next()?, ctx)?; + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + rctx.block + .push(crate::Statement::Store { pointer, value }, span); + return Ok(None); + } + "atomicAdd" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Add, + arguments, + ctx, + )?)) + } + "atomicSub" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Subtract, + arguments, + ctx, + )?)) + } + "atomicAnd" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::And, + arguments, + ctx, + )?)) + } + "atomicOr" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::InclusiveOr, + arguments, + ctx, + )?)) + } + "atomicXor" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::ExclusiveOr, + arguments, + ctx, + )?)) + } + "atomicMin" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Min, + arguments, + ctx, + )?)) + } + "atomicMax" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Max, + arguments, + ctx, + )?)) + } + "atomicExchange" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Exchange { compare: None }, + arguments, + ctx, + )?)) + } + "atomicCompareExchangeWeak" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let pointer = self.atomic_pointer(args.next()?, ctx)?; + + let compare = self.expression(args.next()?, ctx)?; + + let value = args.next()?; + let value_span = ctx.ast_expressions.get_span(value); + let value = self.expression(value, ctx)?; + + args.finish()?; + + let expression = match *resolve_inner!(ctx, value) { + crate::TypeInner::Scalar(scalar) => { + crate::Expression::AtomicResult { + ty: ctx.module.generate_predeclared_type( + crate::PredeclaredType::AtomicCompareExchangeWeakResult( + scalar, + ), + ), + comparison: true, + } + } + _ => return Err(Error::InvalidAtomicOperandType(value_span)), + }; + + let result = ctx.interrupt_emitter(expression, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::Atomic { + pointer, + fun: crate::AtomicFunction::Exchange { + compare: Some(compare), + }, + value, + result, + }, + span, + ); + return Ok(Some(result)); + } + "storageBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::STORAGE), span); + return Ok(None); + } + "workgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); + return Ok(None); + } + "workgroupUniformLoad" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = args.next()?; + args.finish()?; + + let pointer = self.expression(expr, ctx)?; + let result_ty = match *resolve_inner!(ctx, pointer) { + crate::TypeInner::Pointer { + base, + space: crate::AddressSpace::WorkGroup, + } => base, + ref other => { + log::error!("Type {other:?} passed to workgroupUniformLoad"); + let span = ctx.ast_expressions.get_span(expr); + return Err(Error::InvalidWorkGroupUniformLoad(span)); + } + }; + let result = ctx.interrupt_emitter( + crate::Expression::WorkGroupUniformLoadResult { ty: result_ty }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::WorkGroupUniformLoad { pointer, result }, + span, + ); + + return Ok(Some(result)); + } + "textureStore" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = self.expression(image, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (_, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let value = self.expression(args.next()?, ctx)?; + + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + let stmt = crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + }; + rctx.block.push(stmt, span); + return Ok(None); + } + "textureLoad" => { + let mut args = ctx.prepare_args(arguments, 2, span); + + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = self.expression(image, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (class, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let level = class + .is_mipmapped() + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let sample = class + .is_multisampled() + .then(|| self.expression(args.next()?, ctx)) + .transpose()?; + + args.finish()?; + + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + level, + sample, + } + } + "textureDimensions" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + let level = args + .next() + .map(|arg| self.expression(arg, ctx)) + .ok() + .transpose()?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::Size { level }, + } + } + "textureNumLevels" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumLevels, + } + } + "textureNumLayers" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumLayers, + } + } + "textureNumSamples" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumSamples, + } + } + "rayQueryInitialize" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + let acceleration_structure = self.expression(args.next()?, ctx)?; + let descriptor = self.expression(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_ray_desc_type(); + let fun = crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + }; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + rctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(None); + } + "rayQueryProceed" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let result = ctx.interrupt_emitter( + crate::Expression::RayQueryProceedResult, + span, + )?; + let fun = crate::RayQueryFunction::Proceed { result }; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(Some(result)); + } + "rayQueryGetCommittedIntersection" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_ray_intersection_type(); + + crate::Expression::RayQueryGetIntersection { + query, + committed: true, + } + } + "RayDesc" => { + let ty = ctx.module.generate_ray_desc_type(); + let handle = self.construct( + span, + &ast::ConstructorType::Type(ty), + function.span, + arguments, + ctx, + )?; + return Ok(Some(handle)); + } + _ => return Err(Error::UnknownIdent(function.span, function.name)), + } + }; + + let expr = ctx.append_expression(expr, span)?; + Ok(Some(expr)) + } + } + } + + fn atomic_pointer( + &mut self, + expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let pointer = self.expression(expr, ctx)?; + + match *resolve_inner!(ctx, pointer) { + crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { + crate::TypeInner::Atomic { .. } => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to atomic op", other); + Err(Error::InvalidAtomicPointer(span)) + } + }, + ref other => { + log::error!("Type {:?} passed to atomic op", other); + Err(Error::InvalidAtomicPointer(span)) + } + } + } + + fn atomic_helper( + &mut self, + span: Span, + fun: crate::AtomicFunction, + args: &[Handle>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(args, 2, span); + + let pointer = self.atomic_pointer(args.next()?, ctx)?; + + let value = args.next()?; + let value = self.expression(value, ctx)?; + let ty = ctx.register_type(value)?; + + args.finish()?; + + let result = ctx.interrupt_emitter( + crate::Expression::AtomicResult { + ty, + comparison: false, + }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::Atomic { + pointer, + fun, + value, + result, + }, + span, + ); + Ok(result) + } + + fn texture_sample_helper( + &mut self, + fun: Texture, + args: &[Handle>], + span: Span, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result> { + let mut args = ctx.prepare_args(args, fun.min_argument_count(), span); + + fn get_image_and_span<'source>( + lowerer: &mut Lowerer<'source, '_>, + args: &mut ArgumentContext<'_, 'source>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<(Handle, Span), Error<'source>> { + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = lowerer.expression(image, ctx)?; + Ok((image, image_span)) + } + + let (image, image_span, gather) = match fun { + Texture::Gather => { + let image_or_component = args.next()?; + let image_or_component_span = ctx.ast_expressions.get_span(image_or_component); + // Gathers from depth textures don't take an initial `component` argument. + let lowered_image_or_component = self.expression(image_or_component, ctx)?; + + match *resolve_inner!(ctx, lowered_image_or_component) { + crate::TypeInner::Image { + class: crate::ImageClass::Depth { .. }, + .. + } => ( + lowered_image_or_component, + image_or_component_span, + Some(crate::SwizzleComponent::X), + ), + _ => { + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; + ( + image, + image_span, + Some(ctx.gather_component( + lowered_image_or_component, + image_or_component_span, + span, + )?), + ) + } + } + } + Texture::GatherCompare => { + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; + (image, image_span, Some(crate::SwizzleComponent::X)) + } + + _ => { + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; + (image, image_span, None) + } + }; + + let sampler = self.expression(args.next()?, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (_, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| self.expression(args.next()?, ctx)) + .transpose()?; + + let (level, depth_ref) = match fun { + Texture::Gather => (crate::SampleLevel::Zero, None), + Texture::GatherCompare => { + let reference = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Zero, Some(reference)) + } + + Texture::Sample => (crate::SampleLevel::Auto, None), + Texture::SampleBias => { + let bias = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Bias(bias), None) + } + Texture::SampleCompare => { + let reference = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Auto, Some(reference)) + } + Texture::SampleCompareLevel => { + let reference = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Zero, Some(reference)) + } + Texture::SampleGrad => { + let x = self.expression(args.next()?, ctx)?; + let y = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Gradient { x, y }, None) + } + Texture::SampleLevel => { + let level = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Exact(level), None) + } + }; + + let offset = args + .next() + .map(|arg| self.expression(arg, &mut ctx.as_const())) + .ok() + .transpose()?; + + args.finish()?; + + Ok(crate::Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + }) + } + + fn r#struct( + &mut self, + s: &ast::Struct<'source>, + span: Span, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut offset = 0; + let mut struct_alignment = Alignment::ONE; + let mut members = Vec::with_capacity(s.members.len()); + + for member in s.members.iter() { + let ty = self.resolve_ast_type(member.ty, ctx)?; + + self.layouter.update(ctx.module.to_ctx()).unwrap(); + + let member_min_size = self.layouter[ty].size; + let member_min_alignment = self.layouter[ty].alignment; + + let member_size = if let Some(size_expr) = member.size { + let (size, span) = self.const_u32(size_expr, &mut ctx.as_const())?; + if size < member_min_size { + return Err(Error::SizeAttributeTooLow(span, member_min_size)); + } else { + size + } + } else { + member_min_size + }; + + let member_alignment = if let Some(align_expr) = member.align { + let (align, span) = self.const_u32(align_expr, &mut ctx.as_const())?; + if let Some(alignment) = Alignment::new(align) { + if alignment < member_min_alignment { + return Err(Error::AlignAttributeTooLow(span, member_min_alignment)); + } else { + alignment + } + } else { + return Err(Error::NonPowerOfTwoAlignAttribute(span)); + } + } else { + member_min_alignment + }; + + let binding = self.binding(&member.binding, ty, ctx)?; + + offset = member_alignment.round_up(offset); + struct_alignment = struct_alignment.max(member_alignment); + + members.push(crate::StructMember { + name: Some(member.name.name.to_owned()), + ty, + binding, + offset, + }); + + offset += member_size; + } + + let size = struct_alignment.round_up(offset); + let inner = crate::TypeInner::Struct { + members, + span: size, + }; + + let handle = ctx.module.types.insert( + crate::Type { + name: Some(s.name.name.to_string()), + inner, + }, + span, + ); + Ok(handle) + } + + fn const_u32( + &mut self, + expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<(u32, Span), Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let expr = self.expression(expr, ctx)?; + let value = ctx + .module + .to_ctx() + .eval_expr_to_u32(expr) + .map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } + crate::proc::U32EvalError::Negative => Error::ExpectedNonNegative(span), + })?; + Ok((value, span)) + } + + fn array_size( + &mut self, + size: ast::ArraySize<'source>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result> { + Ok(match size { + ast::ArraySize::Constant(expr) => { + let span = ctx.ast_expressions.get_span(expr); + let const_expr = self.expression(expr, &mut ctx.as_const())?; + let len = + ctx.module + .to_ctx() + .eval_expr_to_u32(const_expr) + .map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } + crate::proc::U32EvalError::Negative => { + Error::ExpectedPositiveArrayLength(span) + } + })?; + let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?; + crate::ArraySize::Constant(size) + } + ast::ArraySize::Dynamic => crate::ArraySize::Dynamic, + }) + } + + /// Build the Naga equivalent of a named AST type. + /// + /// Return a Naga `Handle` representing the front-end type + /// `handle`, which should be named `name`, if given. + /// + /// If `handle` refers to a type cached in [`SpecialTypes`], + /// `name` may be ignored. + /// + /// [`SpecialTypes`]: crate::SpecialTypes + fn resolve_named_ast_type( + &mut self, + handle: Handle>, + name: Option, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let inner = match ctx.types[handle] { + ast::Type::Scalar(scalar) => scalar.to_inner_scalar(), + ast::Type::Vector { size, scalar } => scalar.to_inner_vector(size), + ast::Type::Matrix { + rows, + columns, + width, + } => crate::TypeInner::Matrix { + columns, + rows, + scalar: crate::Scalar::float(width), + }, + ast::Type::Atomic(scalar) => scalar.to_inner_atomic(), + ast::Type::Pointer { base, space } => { + let base = self.resolve_ast_type(base, ctx)?; + crate::TypeInner::Pointer { base, space } + } + ast::Type::Array { base, size } => { + let base = self.resolve_ast_type(base, ctx)?; + let size = self.array_size(size, ctx)?; + + self.layouter.update(ctx.module.to_ctx()).unwrap(); + let stride = self.layouter[base].to_stride(); + + crate::TypeInner::Array { base, size, stride } + } + ast::Type::Image { + dim, + arrayed, + class, + } => crate::TypeInner::Image { + dim, + arrayed, + class, + }, + ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison }, + ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure, + ast::Type::RayQuery => crate::TypeInner::RayQuery, + ast::Type::BindingArray { base, size } => { + let base = self.resolve_ast_type(base, ctx)?; + let size = self.array_size(size, ctx)?; + crate::TypeInner::BindingArray { base, size } + } + ast::Type::RayDesc => { + return Ok(ctx.module.generate_ray_desc_type()); + } + ast::Type::RayIntersection => { + return Ok(ctx.module.generate_ray_intersection_type()); + } + ast::Type::User(ref ident) => { + return match ctx.globals.get(ident.name) { + Some(&LoweredGlobalDecl::Type(handle)) => Ok(handle), + Some(_) => Err(Error::Unexpected(ident.span, ExpectedToken::Type)), + None => Err(Error::UnknownType(ident.span)), + } + } + }; + + Ok(ctx.ensure_type_exists(name, inner)) + } + + /// Return a Naga `Handle` representing the front-end type `handle`. + fn resolve_ast_type( + &mut self, + handle: Handle>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + self.resolve_named_ast_type(handle, None, ctx) + } + + fn binding( + &mut self, + binding: &Option>, + ty: Handle, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + Ok(match *binding { + Some(ast::Binding::BuiltIn(b)) => Some(crate::Binding::BuiltIn(b)), + Some(ast::Binding::Location { + location, + second_blend_source, + interpolation, + sampling, + }) => { + let mut binding = crate::Binding::Location { + location: self.const_u32(location, &mut ctx.as_const())?.0, + second_blend_source, + interpolation, + sampling, + }; + binding.apply_default_interpolation(&ctx.module.types[ty].inner); + Some(binding) + } + None => None, + }) + } + + fn ray_query_pointer( + &mut self, + expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let pointer = self.expression(expr, ctx)?; + + match *resolve_inner!(ctx, pointer) { + crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { + crate::TypeInner::RayQuery => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to ray query op", other); + Err(Error::InvalidRayQueryPointer(span)) + } + }, + ref other => { + log::error!("Type {:?} passed to ray query op", other); + Err(Error::InvalidRayQueryPointer(span)) + } + } + } +} -- cgit v1.2.3