From 43a97878ce14b72f0981164f87f2e35e14151312 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 11:22:09 +0200 Subject: Adding upstream version 110.0.1. Signed-off-by: Daniel Baumann --- .../rust/naga/src/front/wgsl/construction.rs | 679 +++ third_party/rust/naga/src/front/wgsl/conv.rs | 225 + third_party/rust/naga/src/front/wgsl/lexer.rs | 671 +++ third_party/rust/naga/src/front/wgsl/mod.rs | 4750 ++++++++++++++++++++ third_party/rust/naga/src/front/wgsl/number.rs | 442 ++ third_party/rust/naga/src/front/wgsl/tests.rs | 458 ++ 6 files changed, 7225 insertions(+) create mode 100644 third_party/rust/naga/src/front/wgsl/construction.rs create mode 100644 third_party/rust/naga/src/front/wgsl/conv.rs create mode 100644 third_party/rust/naga/src/front/wgsl/lexer.rs create mode 100644 third_party/rust/naga/src/front/wgsl/mod.rs create mode 100644 third_party/rust/naga/src/front/wgsl/number.rs create mode 100644 third_party/rust/naga/src/front/wgsl/tests.rs (limited to 'third_party/rust/naga/src/front/wgsl') diff --git a/third_party/rust/naga/src/front/wgsl/construction.rs b/third_party/rust/naga/src/front/wgsl/construction.rs new file mode 100644 index 0000000000..43e719d0f3 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/construction.rs @@ -0,0 +1,679 @@ +use crate::{ + proc::TypeResolution, Arena, ArraySize, Bytes, Constant, ConstantInner, Expression, Handle, + ScalarKind, ScalarValue, Span as NagaSpan, Type, TypeInner, UniqueArena, VectorSize, +}; + +use super::{Error, ExpressionContext, Lexer, Parser, Rule, Span, Token}; + +/// Represents the type of the constructor +/// +/// Vectors, Matrices and Arrays can have partial type information +/// which later gets inferred from the constructor parameters +enum ConstructorType { + Scalar { + kind: ScalarKind, + width: Bytes, + }, + PartialVector { + size: VectorSize, + }, + Vector { + size: VectorSize, + kind: ScalarKind, + width: Bytes, + }, + PartialMatrix { + columns: VectorSize, + rows: VectorSize, + }, + Matrix { + columns: VectorSize, + rows: VectorSize, + width: Bytes, + }, + PartialArray, + Array { + base: Handle, + size: ArraySize, + stride: u32, + }, + Struct(Handle), +} + +impl ConstructorType { + const fn to_type_resolution(&self) -> Option { + Some(match *self { + ConstructorType::Scalar { kind, width } => { + TypeResolution::Value(TypeInner::Scalar { kind, width }) + } + ConstructorType::Vector { size, kind, width } => { + TypeResolution::Value(TypeInner::Vector { size, kind, width }) + } + ConstructorType::Matrix { + columns, + rows, + width, + } => TypeResolution::Value(TypeInner::Matrix { + columns, + rows, + width, + }), + ConstructorType::Array { base, size, stride } => { + TypeResolution::Value(TypeInner::Array { base, size, stride }) + } + ConstructorType::Struct(handle) => TypeResolution::Handle(handle), + _ => return None, + }) + } +} + +impl ConstructorType { + fn to_error_string(&self, types: &UniqueArena, constants: &Arena) -> String { + match *self { + ConstructorType::Scalar { kind, width } => kind.to_wgsl(width), + ConstructorType::PartialVector { size } => { + format!("vec{}", size as u32,) + } + ConstructorType::Vector { size, kind, width } => { + format!("vec{}<{}>", size as u32, kind.to_wgsl(width)) + } + ConstructorType::PartialMatrix { columns, rows } => { + format!("mat{}x{}", columns as u32, rows as u32,) + } + ConstructorType::Matrix { + columns, + rows, + width, + } => { + format!( + "mat{}x{}<{}>", + columns as u32, + rows as u32, + ScalarKind::Float.to_wgsl(width) + ) + } + ConstructorType::PartialArray => "array".to_string(), + ConstructorType::Array { base, size, .. } => { + format!( + "array<{}, {}>", + types[base].name.as_deref().unwrap_or("?"), + match size { + ArraySize::Constant(size) => { + constants[size] + .to_array_length() + .map(|len| len.to_string()) + .unwrap_or_else(|| "?".to_string()) + } + _ => unreachable!(), + } + ) + } + ConstructorType::Struct(handle) => types[handle] + .name + .clone() + .unwrap_or_else(|| "?".to_string()), + } + } +} + +fn parse_constructor_type<'a>( + parser: &mut Parser, + lexer: &mut Lexer<'a>, + word: &'a str, + type_arena: &mut UniqueArena, + const_arena: &mut Arena, +) -> Result, Error<'a>> { + if let Some((kind, width)) = super::conv::get_scalar_type(word) { + return Ok(Some(ConstructorType::Scalar { kind, width })); + } + + let partial = match word { + "vec2" => ConstructorType::PartialVector { + size: VectorSize::Bi, + }, + "vec3" => ConstructorType::PartialVector { + size: VectorSize::Tri, + }, + "vec4" => ConstructorType::PartialVector { + size: VectorSize::Quad, + }, + "mat2x2" => ConstructorType::PartialMatrix { + columns: VectorSize::Bi, + rows: VectorSize::Bi, + }, + "mat2x3" => ConstructorType::PartialMatrix { + columns: VectorSize::Bi, + rows: VectorSize::Tri, + }, + "mat2x4" => ConstructorType::PartialMatrix { + columns: VectorSize::Bi, + rows: VectorSize::Quad, + }, + "mat3x2" => ConstructorType::PartialMatrix { + columns: VectorSize::Tri, + rows: VectorSize::Bi, + }, + "mat3x3" => ConstructorType::PartialMatrix { + columns: VectorSize::Tri, + rows: VectorSize::Tri, + }, + "mat3x4" => ConstructorType::PartialMatrix { + columns: VectorSize::Tri, + rows: VectorSize::Quad, + }, + "mat4x2" => ConstructorType::PartialMatrix { + columns: VectorSize::Quad, + rows: VectorSize::Bi, + }, + "mat4x3" => ConstructorType::PartialMatrix { + columns: VectorSize::Quad, + rows: VectorSize::Tri, + }, + "mat4x4" => ConstructorType::PartialMatrix { + columns: VectorSize::Quad, + rows: VectorSize::Quad, + }, + "array" => ConstructorType::PartialArray, + _ => return Ok(None), + }; + + // parse component type if present + match (lexer.peek().0, partial) { + (Token::Paren('<'), ConstructorType::PartialVector { size }) => { + let (kind, width) = lexer.next_scalar_generic()?; + Ok(Some(ConstructorType::Vector { size, kind, width })) + } + (Token::Paren('<'), ConstructorType::PartialMatrix { columns, rows }) => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + match kind { + ScalarKind::Float => Ok(Some(ConstructorType::Matrix { + columns, + rows, + width, + })), + _ => Err(Error::BadMatrixScalarKind(span, kind, width)), + } + } + (Token::Paren('<'), ConstructorType::PartialArray) => { + lexer.expect_generic_paren('<')?; + let base = parser.parse_type_decl(lexer, None, type_arena, const_arena)?; + let size = if lexer.skip(Token::Separator(',')) { + let const_handle = parser.parse_const_expression(lexer, type_arena, const_arena)?; + ArraySize::Constant(const_handle) + } else { + ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + let stride = { + parser.layouter.update(type_arena, const_arena).unwrap(); + parser.layouter[base].to_stride() + }; + + Ok(Some(ConstructorType::Array { base, size, stride })) + } + (_, partial) => Ok(Some(partial)), + } +} + +/// Expects [`Rule::PrimaryExpr`] on top of rule stack; if returning Some(_), pops it. +pub(super) fn parse_construction<'a>( + parser: &mut Parser, + lexer: &mut Lexer<'a>, + type_name: &'a str, + type_span: Span, + mut ctx: ExpressionContext<'a, '_, '_>, +) -> Result>, Error<'a>> { + assert_eq!( + parser.rules.last().map(|&(ref rule, _)| rule.clone()), + Some(Rule::PrimaryExpr) + ); + let dst_ty = match parser.lookup_type.get(type_name) { + Some(&handle) => ConstructorType::Struct(handle), + None => match parse_constructor_type(parser, lexer, type_name, ctx.types, ctx.constants)? { + Some(inner) => inner, + None => { + match parser.parse_type_decl_impl( + lexer, + super::TypeAttributes::default(), + type_name, + ctx.types, + ctx.constants, + )? { + Some(_) => { + return Err(Error::TypeNotConstructible(type_span)); + } + None => return Ok(None), + } + } + }, + }; + + lexer.open_arguments()?; + + let mut components = Vec::new(); + let mut spans = Vec::new(); + + if lexer.peek().0 == Token::Paren(')') { + let _ = lexer.next(); + } else { + while components.is_empty() || lexer.next_argument()? { + let (component, span) = lexer + .capture_span(|lexer| parser.parse_general_expression(lexer, ctx.reborrow()))?; + components.push(component); + spans.push(span); + } + } + + enum Components<'a> { + None, + One { + component: Handle, + span: Span, + ty: &'a TypeInner, + }, + Many { + components: Vec>, + spans: Vec, + first_component_ty: &'a TypeInner, + }, + } + + impl<'a> Components<'a> { + fn into_components_vec(self) -> Vec> { + match self { + Components::None => vec![], + Components::One { component, .. } => vec![component], + Components::Many { components, .. } => components, + } + } + } + + let components = match *components.as_slice() { + [] => Components::None, + [component] => { + ctx.resolve_type(component)?; + Components::One { + component, + span: spans[0].clone(), + ty: ctx.typifier.get(component, ctx.types), + } + } + [component, ..] => { + ctx.resolve_type(component)?; + Components::Many { + components, + spans, + first_component_ty: ctx.typifier.get(component, ctx.types), + } + } + }; + + let expr = match (components, dst_ty) { + // Empty constructor + (Components::None, dst_ty) => { + let ty = match dst_ty.to_type_resolution() { + Some(TypeResolution::Handle(handle)) => handle, + Some(TypeResolution::Value(inner)) => ctx + .types + .insert(Type { name: None, inner }, Default::default()), + None => return Err(Error::TypeNotInferrable(type_span)), + }; + + return match ctx.create_zero_value_constant(ty) { + Some(constant) => { + let span = parser.pop_rule_span(lexer); + Ok(Some(ctx.interrupt_emitter( + Expression::Constant(constant), + span.into(), + ))) + } + None => Err(Error::TypeNotConstructible(type_span)), + }; + } + + // Scalar constructor & conversion (scalar -> scalar) + ( + Components::One { + component, + ty: &TypeInner::Scalar { .. }, + .. + }, + ConstructorType::Scalar { kind, width }, + ) => Expression::As { + expr: component, + kind, + convert: Some(width), + }, + + // Vector conversion (vector -> vector) + ( + Components::One { + component, + ty: &TypeInner::Vector { size: src_size, .. }, + .. + }, + ConstructorType::Vector { + size: dst_size, + kind: dst_kind, + width: dst_width, + }, + ) if dst_size == src_size => Expression::As { + expr: component, + kind: dst_kind, + convert: Some(dst_width), + }, + + // Vector conversion (vector -> vector) - partial + ( + Components::One { + component, + ty: + &TypeInner::Vector { + size: src_size, + kind: src_kind, + .. + }, + .. + }, + ConstructorType::PartialVector { size: dst_size }, + ) if dst_size == src_size => Expression::As { + expr: component, + kind: src_kind, + convert: None, + }, + + // Matrix conversion (matrix -> matrix) + ( + Components::One { + component, + ty: + &TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + ConstructorType::Matrix { + columns: dst_columns, + rows: dst_rows, + width: dst_width, + }, + ) if dst_columns == src_columns && dst_rows == src_rows => Expression::As { + expr: component, + kind: ScalarKind::Float, + convert: Some(dst_width), + }, + + // Matrix conversion (matrix -> matrix) - partial + ( + Components::One { + component, + ty: + &TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + ConstructorType::PartialMatrix { + columns: dst_columns, + rows: dst_rows, + }, + ) if dst_columns == src_columns && dst_rows == src_rows => Expression::As { + expr: component, + kind: ScalarKind::Float, + convert: None, + }, + + // Vector constructor (splat) - infer type + ( + Components::One { + component, + ty: &TypeInner::Scalar { .. }, + .. + }, + ConstructorType::PartialVector { size }, + ) => Expression::Splat { + size, + value: component, + }, + + // Vector constructor (splat) + ( + Components::One { + component, + ty: + &TypeInner::Scalar { + kind: src_kind, + width: src_width, + .. + }, + .. + }, + ConstructorType::Vector { + size, + kind: dst_kind, + width: dst_width, + }, + ) if dst_kind == src_kind || dst_width == src_width => Expression::Splat { + size, + value: component, + }, + + // Vector constructor (by elements) + ( + Components::Many { + components, + first_component_ty: + &TypeInner::Scalar { kind, width } | &TypeInner::Vector { kind, width, .. }, + .. + }, + ConstructorType::PartialVector { size }, + ) + | ( + Components::Many { + components, + first_component_ty: &TypeInner::Scalar { .. } | &TypeInner::Vector { .. }, + .. + }, + ConstructorType::Vector { size, width, kind }, + ) => { + let ty = ctx.types.insert( + Type { + name: None, + inner: TypeInner::Vector { size, kind, width }, + }, + Default::default(), + ); + Expression::Compose { ty, components } + } + + // Matrix constructor (by elements) + ( + Components::Many { + components, + first_component_ty: &TypeInner::Scalar { width, .. }, + .. + }, + ConstructorType::PartialMatrix { columns, rows }, + ) + | ( + Components::Many { + components, + first_component_ty: &TypeInner::Scalar { .. }, + .. + }, + ConstructorType::Matrix { + columns, + rows, + width, + }, + ) => { + let vec_ty = ctx.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + width, + kind: ScalarKind::Float, + size: rows, + }, + }, + Default::default(), + ); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.expressions.append( + Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect(); + + let ty = ctx.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + width, + }, + }, + Default::default(), + ); + Expression::Compose { ty, components } + } + + // Matrix constructor (by columns) + ( + Components::Many { + components, + first_component_ty: &TypeInner::Vector { width, .. }, + .. + }, + ConstructorType::PartialMatrix { columns, rows }, + ) + | ( + Components::Many { + components, + first_component_ty: &TypeInner::Vector { .. }, + .. + }, + ConstructorType::Matrix { + columns, + rows, + width, + }, + ) => { + let ty = ctx.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + width, + }, + }, + Default::default(), + ); + Expression::Compose { ty, components } + } + + // Array constructor - infer type + (components, ConstructorType::PartialArray) => { + let components = components.into_components_vec(); + + let base = match ctx.typifier[components[0]].clone() { + TypeResolution::Handle(ty) => ty, + TypeResolution::Value(inner) => ctx + .types + .insert(Type { name: None, inner }, Default::default()), + }; + + let size = Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Uint(components.len() as u64), + }, + }; + + let inner = TypeInner::Array { + base, + size: ArraySize::Constant(ctx.constants.append(size, Default::default())), + stride: { + parser.layouter.update(ctx.types, ctx.constants).unwrap(); + parser.layouter[base].to_stride() + }, + }; + + let ty = ctx + .types + .insert(Type { name: None, inner }, Default::default()); + + Expression::Compose { ty, components } + } + + // Array constructor + (components, ConstructorType::Array { base, size, stride }) => { + let components = components.into_components_vec(); + let inner = TypeInner::Array { base, size, stride }; + let ty = ctx + .types + .insert(Type { name: None, inner }, Default::default()); + Expression::Compose { ty, components } + } + + // Struct constructor + (components, ConstructorType::Struct(ty)) => Expression::Compose { + ty, + components: components.into_components_vec(), + }, + + // ERRORS + + // Bad conversion (type cast) + ( + Components::One { + span, ty: src_ty, .. + }, + dst_ty, + ) => { + return Err(Error::BadTypeCast { + span, + from_type: src_ty.to_wgsl(ctx.types, ctx.constants), + to_type: dst_ty.to_error_string(ctx.types, ctx.constants), + }); + } + + // Too many parameters for scalar constructor + (Components::Many { spans, .. }, ConstructorType::Scalar { .. }) => { + return Err(Error::UnexpectedComponents(Span { + start: spans[1].start, + end: spans.last().unwrap().end, + })); + } + + // Parameters are of the wrong type for vector or matrix constructor + ( + Components::Many { spans, .. }, + ConstructorType::Vector { .. } + | ConstructorType::Matrix { .. } + | ConstructorType::PartialVector { .. } + | ConstructorType::PartialMatrix { .. }, + ) => { + return Err(Error::InvalidConstructorComponentType(spans[0].clone(), 0)); + } + }; + + let span = NagaSpan::from(parser.pop_rule_span(lexer)); + Ok(Some(ctx.expressions.append(expr, span))) +} diff --git a/third_party/rust/naga/src/front/wgsl/conv.rs b/third_party/rust/naga/src/front/wgsl/conv.rs new file mode 100644 index 0000000000..ba41648757 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/conv.rs @@ -0,0 +1,225 @@ +use super::{Error, Span}; + +pub fn map_address_space(word: &str, span: Span) -> Result> { + match word { + "private" => Ok(crate::AddressSpace::Private), + "workgroup" => Ok(crate::AddressSpace::WorkGroup), + "uniform" => Ok(crate::AddressSpace::Uniform), + "storage" => Ok(crate::AddressSpace::Storage { + access: crate::StorageAccess::default(), + }), + "push_constant" => Ok(crate::AddressSpace::PushConstant), + "function" => Ok(crate::AddressSpace::Function), + _ => Err(Error::UnknownAddressSpace(span)), + } +} + +pub fn map_built_in(word: &str, span: Span) -> Result> { + Ok(match word { + "position" => crate::BuiltIn::Position { invariant: false }, + // vertex + "vertex_index" => crate::BuiltIn::VertexIndex, + "instance_index" => crate::BuiltIn::InstanceIndex, + "view_index" => crate::BuiltIn::ViewIndex, + // fragment + "front_facing" => crate::BuiltIn::FrontFacing, + "frag_depth" => crate::BuiltIn::FragDepth, + "primitive_index" => crate::BuiltIn::PrimitiveIndex, + "sample_index" => crate::BuiltIn::SampleIndex, + "sample_mask" => crate::BuiltIn::SampleMask, + // compute + "global_invocation_id" => crate::BuiltIn::GlobalInvocationId, + "local_invocation_id" => crate::BuiltIn::LocalInvocationId, + "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, + "workgroup_id" => crate::BuiltIn::WorkGroupId, + "workgroup_size" => crate::BuiltIn::WorkGroupSize, + "num_workgroups" => crate::BuiltIn::NumWorkGroups, + _ => return Err(Error::UnknownBuiltin(span)), + }) +} + +pub fn map_interpolation(word: &str, span: Span) -> Result> { + match word { + "linear" => Ok(crate::Interpolation::Linear), + "flat" => Ok(crate::Interpolation::Flat), + "perspective" => Ok(crate::Interpolation::Perspective), + _ => Err(Error::UnknownAttribute(span)), + } +} + +pub fn map_sampling(word: &str, span: Span) -> Result> { + match word { + "center" => Ok(crate::Sampling::Center), + "centroid" => Ok(crate::Sampling::Centroid), + "sample" => Ok(crate::Sampling::Sample), + _ => Err(Error::UnknownAttribute(span)), + } +} + +pub fn map_storage_format(word: &str, span: Span) -> Result> { + use crate::StorageFormat as Sf; + Ok(match word { + "r8unorm" => Sf::R8Unorm, + "r8snorm" => Sf::R8Snorm, + "r8uint" => Sf::R8Uint, + "r8sint" => Sf::R8Sint, + "r16uint" => Sf::R16Uint, + "r16sint" => Sf::R16Sint, + "r16float" => Sf::R16Float, + "rg8unorm" => Sf::Rg8Unorm, + "rg8snorm" => Sf::Rg8Snorm, + "rg8uint" => Sf::Rg8Uint, + "rg8sint" => Sf::Rg8Sint, + "r32uint" => Sf::R32Uint, + "r32sint" => Sf::R32Sint, + "r32float" => Sf::R32Float, + "rg16uint" => Sf::Rg16Uint, + "rg16sint" => Sf::Rg16Sint, + "rg16float" => Sf::Rg16Float, + "rgba8unorm" => Sf::Rgba8Unorm, + "rgba8snorm" => Sf::Rgba8Snorm, + "rgba8uint" => Sf::Rgba8Uint, + "rgba8sint" => Sf::Rgba8Sint, + "rgb10a2unorm" => Sf::Rgb10a2Unorm, + "rg11b10float" => Sf::Rg11b10Float, + "rg32uint" => Sf::Rg32Uint, + "rg32sint" => Sf::Rg32Sint, + "rg32float" => Sf::Rg32Float, + "rgba16uint" => Sf::Rgba16Uint, + "rgba16sint" => Sf::Rgba16Sint, + "rgba16float" => Sf::Rgba16Float, + "rgba32uint" => Sf::Rgba32Uint, + "rgba32sint" => Sf::Rgba32Sint, + "rgba32float" => Sf::Rgba32Float, + _ => return Err(Error::UnknownStorageFormat(span)), + }) +} + +pub fn get_scalar_type(word: &str) -> Option<(crate::ScalarKind, crate::Bytes)> { + match word { + "f16" => Some((crate::ScalarKind::Float, 2)), + "f32" => Some((crate::ScalarKind::Float, 4)), + "f64" => Some((crate::ScalarKind::Float, 8)), + "i8" => Some((crate::ScalarKind::Sint, 1)), + "i16" => Some((crate::ScalarKind::Sint, 2)), + "i32" => Some((crate::ScalarKind::Sint, 4)), + "i64" => Some((crate::ScalarKind::Sint, 8)), + "u8" => Some((crate::ScalarKind::Uint, 1)), + "u16" => Some((crate::ScalarKind::Uint, 2)), + "u32" => Some((crate::ScalarKind::Uint, 4)), + "u64" => Some((crate::ScalarKind::Uint, 8)), + "bool" => Some((crate::ScalarKind::Bool, crate::BOOL_WIDTH)), + _ => None, + } +} + +pub fn map_derivative_axis(word: &str) -> Option { + match word { + "dpdx" => Some(crate::DerivativeAxis::X), + "dpdy" => Some(crate::DerivativeAxis::Y), + "fwidth" => Some(crate::DerivativeAxis::Width), + _ => None, + } +} + +pub fn map_relational_fun(word: &str) -> Option { + match word { + "any" => Some(crate::RelationalFunction::Any), + "all" => Some(crate::RelationalFunction::All), + "isFinite" => Some(crate::RelationalFunction::IsFinite), + "isNormal" => Some(crate::RelationalFunction::IsNormal), + _ => None, + } +} + +pub fn map_standard_fun(word: &str) -> Option { + use crate::MathFunction as Mf; + Some(match word { + // comparison + "abs" => Mf::Abs, + "min" => Mf::Min, + "max" => Mf::Max, + "clamp" => Mf::Clamp, + "saturate" => Mf::Saturate, + // trigonometry + "cos" => Mf::Cos, + "cosh" => Mf::Cosh, + "sin" => Mf::Sin, + "sinh" => Mf::Sinh, + "tan" => Mf::Tan, + "tanh" => Mf::Tanh, + "acos" => Mf::Acos, + "asin" => Mf::Asin, + "atan" => Mf::Atan, + "atan2" => Mf::Atan2, + "radians" => Mf::Radians, + "degrees" => Mf::Degrees, + // decomposition + "ceil" => Mf::Ceil, + "floor" => Mf::Floor, + "round" => Mf::Round, + "fract" => Mf::Fract, + "trunc" => Mf::Trunc, + "modf" => Mf::Modf, + "frexp" => Mf::Frexp, + "ldexp" => Mf::Ldexp, + // exponent + "exp" => Mf::Exp, + "exp2" => Mf::Exp2, + "log" => Mf::Log, + "log2" => Mf::Log2, + "pow" => Mf::Pow, + // geometry + "dot" => Mf::Dot, + "outerProduct" => Mf::Outer, + "cross" => Mf::Cross, + "distance" => Mf::Distance, + "length" => Mf::Length, + "normalize" => Mf::Normalize, + "faceForward" => Mf::FaceForward, + "reflect" => Mf::Reflect, + // computational + "sign" => Mf::Sign, + "fma" => Mf::Fma, + "mix" => Mf::Mix, + "step" => Mf::Step, + "smoothstep" => Mf::SmoothStep, + "sqrt" => Mf::Sqrt, + "inverseSqrt" => Mf::InverseSqrt, + "transpose" => Mf::Transpose, + "determinant" => Mf::Determinant, + // bits + "countOneBits" => Mf::CountOneBits, + "reverseBits" => Mf::ReverseBits, + "extractBits" => Mf::ExtractBits, + "insertBits" => Mf::InsertBits, + "firstTrailingBit" => Mf::FindLsb, + "firstLeadingBit" => Mf::FindMsb, + // data packing + "pack4x8snorm" => Mf::Pack4x8snorm, + "pack4x8unorm" => Mf::Pack4x8unorm, + "pack2x16snorm" => Mf::Pack2x16snorm, + "pack2x16unorm" => Mf::Pack2x16unorm, + "pack2x16float" => Mf::Pack2x16float, + // data unpacking + "unpack4x8snorm" => Mf::Unpack4x8snorm, + "unpack4x8unorm" => Mf::Unpack4x8unorm, + "unpack2x16snorm" => Mf::Unpack2x16snorm, + "unpack2x16unorm" => Mf::Unpack2x16unorm, + "unpack2x16float" => Mf::Unpack2x16float, + _ => return None, + }) +} + +pub fn map_conservative_depth( + word: &str, + span: Span, +) -> Result> { + use crate::ConservativeDepth as Cd; + match word { + "greater_equal" => Ok(Cd::GreaterEqual), + "less_equal" => Ok(Cd::LessEqual), + "unchanged" => Ok(Cd::Unchanged), + _ => Err(Error::UnknownConservativeDepth(span)), + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lexer.rs b/third_party/rust/naga/src/front/wgsl/lexer.rs new file mode 100644 index 0000000000..35fe450892 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lexer.rs @@ -0,0 +1,671 @@ +use super::{conv, number::consume_number, Error, ExpectedToken, Span, Token, TokenSpan}; + +fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) { + let pos = input.find(|c| !what(c)).unwrap_or(input.len()); + input.split_at(pos) +} + +fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) { + let mut chars = input.chars(); + let cur = match chars.next() { + Some(c) => c, + None => return (Token::End, ""), + }; + match cur { + ':' | ';' | ',' => (Token::Separator(cur), chars.as_str()), + '.' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('0'..='9') => consume_number(input), + _ => (Token::Separator(cur), og_chars), + } + } + '@' => (Token::Attribute, chars.as_str()), + '(' | ')' | '{' | '}' | '[' | ']' => (Token::Paren(cur), chars.as_str()), + '<' | '>' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') if !generic => (Token::LogicalOperation(cur), chars.as_str()), + Some(c) if c == cur && !generic => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::ShiftOperation(cur), og_chars), + } + } + _ => (Token::Paren(cur), og_chars), + } + } + '0'..='9' => consume_number(input), + '/' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('/') => { + let _ = chars.position(is_comment_end); + (Token::Trivia, chars.as_str()) + } + Some('*') => { + let mut depth = 1; + let mut prev = None; + + for c in &mut chars { + match (prev, c) { + (Some('*'), '/') => { + prev = None; + depth -= 1; + if depth == 0 { + return (Token::Trivia, chars.as_str()); + } + } + (Some('/'), '*') => { + prev = None; + depth += 1; + } + _ => { + prev = Some(c); + } + } + } + + (Token::End, "") + } + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '-' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('>') => (Token::Arrow, chars.as_str()), + Some('0'..='9' | '.') => consume_number(input), + Some('-') => (Token::DecrementOperation, chars.as_str()), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '+' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('+') => (Token::IncrementOperation, chars.as_str()), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '*' | '%' | '^' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '~' => (Token::Operation(cur), chars.as_str()), + '=' | '!' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') => (Token::LogicalOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '&' | '|' => { + let og_chars = chars.as_str(); + match chars.next() { + Some(c) if c == cur => (Token::LogicalOperation(cur), chars.as_str()), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + _ if is_blankspace(cur) => { + let (_, rest) = consume_any(input, is_blankspace); + (Token::Trivia, rest) + } + _ if is_word_start(cur) => { + let (word, rest) = consume_any(input, is_word_part); + (Token::Word(word), rest) + } + _ => (Token::Unknown(cur), chars.as_str()), + } +} + +/// Returns whether or not a char is a comment end +/// (Unicode Pattern_White_Space excluding U+0020, U+0009, U+200E and U+200F) +const fn is_comment_end(c: char) -> bool { + match c { + '\u{000a}'..='\u{000d}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true, + _ => false, + } +} + +/// Returns whether or not a char is a blankspace (Unicode Pattern_White_Space) +const fn is_blankspace(c: char) -> bool { + match c { + '\u{0020}' + | '\u{0009}'..='\u{000d}' + | '\u{0085}' + | '\u{200e}' + | '\u{200f}' + | '\u{2028}' + | '\u{2029}' => true, + _ => false, + } +} + +/// Returns whether or not a char is a word start (Unicode XID_Start + '_') +fn is_word_start(c: char) -> bool { + c == '_' || unicode_xid::UnicodeXID::is_xid_start(c) +} + +/// Returns whether or not a char is a word part (Unicode XID_Continue) +fn is_word_part(c: char) -> bool { + unicode_xid::UnicodeXID::is_xid_continue(c) +} + +#[derive(Clone)] +pub(super) struct Lexer<'a> { + input: &'a str, + pub(super) source: &'a str, + // The byte offset of the end of the last non-trivia token. + last_end_offset: usize, +} + +impl<'a> Lexer<'a> { + pub(super) const fn new(input: &'a str) -> Self { + Lexer { + input, + source: input, + last_end_offset: 0, + } + } + + pub(super) const fn _leftover_span(&self) -> Span { + self.source.len() - self.input.len()..self.source.len() + } + + /// Calls the function with a lexer and returns the result of the function as well as the span for everything the function parsed + /// + /// # Examples + /// ```ignore + /// let lexer = Lexer::new("5"); + /// let (value, span) = lexer.capture_span(Lexer::next_uint_literal); + /// assert_eq!(value, 5); + /// ``` + #[inline] + pub fn capture_span( + &mut self, + inner: impl FnOnce(&mut Self) -> Result, + ) -> Result<(T, Span), E> { + let start = self.current_byte_offset(); + let res = inner(self)?; + let end = self.current_byte_offset(); + Ok((res, start..end)) + } + + pub(super) fn start_byte_offset(&mut self) -> usize { + loop { + // Eat all trivia because `next` doesn't eat trailing trivia. + let (token, rest) = consume_token(self.input, false); + if let Token::Trivia = token { + self.input = rest; + } else { + return self.current_byte_offset(); + } + } + } + + pub(super) const fn end_byte_offset(&self) -> usize { + self.last_end_offset + } + + fn peek_token_and_rest(&mut self) -> (TokenSpan<'a>, &'a str) { + let mut cloned = self.clone(); + let token = cloned.next(); + let rest = cloned.input; + (token, rest) + } + + const fn current_byte_offset(&self) -> usize { + self.source.len() - self.input.len() + } + + pub(super) const fn span_from(&self, offset: usize) -> Span { + offset..self.end_byte_offset() + } + + #[must_use] + pub(super) fn next(&mut self) -> TokenSpan<'a> { + let mut start_byte_offset = self.current_byte_offset(); + loop { + let (token, rest) = consume_token(self.input, false); + self.input = rest; + match token { + Token::Trivia => start_byte_offset = self.current_byte_offset(), + _ => { + self.last_end_offset = self.current_byte_offset(); + return (token, start_byte_offset..self.last_end_offset); + } + } + } + } + + #[must_use] + pub(super) fn next_generic(&mut self) -> TokenSpan<'a> { + let mut start_byte_offset = self.current_byte_offset(); + loop { + let (token, rest) = consume_token(self.input, true); + self.input = rest; + match token { + Token::Trivia => start_byte_offset = self.current_byte_offset(), + _ => return (token, start_byte_offset..self.current_byte_offset()), + } + } + } + + #[must_use] + pub(super) fn peek(&mut self) -> TokenSpan<'a> { + let (token, _) = self.peek_token_and_rest(); + token + } + + pub(super) fn expect_span( + &mut self, + expected: Token<'a>, + ) -> Result, Error<'a>> { + let next = self.next(); + if next.0 == expected { + Ok(next.1) + } else { + Err(Error::Unexpected(next.1, ExpectedToken::Token(expected))) + } + } + + pub(super) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> { + self.expect_span(expected)?; + Ok(()) + } + + pub(super) fn expect_generic_paren(&mut self, expected: char) -> Result<(), Error<'a>> { + let next = self.next_generic(); + if next.0 == Token::Paren(expected) { + Ok(()) + } else { + Err(Error::Unexpected( + next.1, + ExpectedToken::Token(Token::Paren(expected)), + )) + } + } + + /// If the next token matches it is skipped and true is returned + pub(super) fn skip(&mut self, what: Token<'_>) -> bool { + let (peeked_token, rest) = self.peek_token_and_rest(); + if peeked_token.0 == what { + self.input = rest; + true + } else { + false + } + } + + pub(super) fn next_ident_with_span(&mut self) -> Result<(&'a str, Span), Error<'a>> { + match self.next() { + (Token::Word(word), span) if word == "_" => { + Err(Error::InvalidIdentifierUnderscore(span)) + } + (Token::Word(word), span) if word.starts_with("__") => { + Err(Error::ReservedIdentifierPrefix(span)) + } + (Token::Word(word), span) => Ok((word, span)), + other => Err(Error::Unexpected(other.1, ExpectedToken::Identifier)), + } + } + + pub(super) fn next_ident(&mut self) -> Result<&'a str, Error<'a>> { + self.next_ident_with_span().map(|(word, _)| word) + } + + /// Parses a generic scalar type, for example ``. + pub(super) fn next_scalar_generic( + &mut self, + ) -> Result<(crate::ScalarKind, crate::Bytes), Error<'a>> { + self.expect_generic_paren('<')?; + let pair = match self.next() { + (Token::Word(word), span) => { + conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span)) + } + (_, span) => Err(Error::UnknownScalarType(span)), + }?; + self.expect_generic_paren('>')?; + Ok(pair) + } + + /// Parses a generic scalar type, for example ``. + /// + /// Returns the span covering the inner type, excluding the brackets. + pub(super) fn next_scalar_generic_with_span( + &mut self, + ) -> Result<(crate::ScalarKind, crate::Bytes, Span), Error<'a>> { + self.expect_generic_paren('<')?; + let pair = match self.next() { + (Token::Word(word), span) => conv::get_scalar_type(word) + .map(|(a, b)| (a, b, span.clone())) + .ok_or(Error::UnknownScalarType(span)), + (_, span) => Err(Error::UnknownScalarType(span)), + }?; + self.expect_generic_paren('>')?; + Ok(pair) + } + + pub(super) fn next_storage_access(&mut self) -> Result> { + let (ident, span) = self.next_ident_with_span()?; + match ident { + "read" => Ok(crate::StorageAccess::LOAD), + "write" => Ok(crate::StorageAccess::STORE), + "read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE), + _ => Err(Error::UnknownAccess(span)), + } + } + + pub(super) fn next_format_generic( + &mut self, + ) -> Result<(crate::StorageFormat, crate::StorageAccess), Error<'a>> { + self.expect(Token::Paren('<'))?; + let (ident, ident_span) = self.next_ident_with_span()?; + let format = conv::map_storage_format(ident, ident_span)?; + self.expect(Token::Separator(','))?; + let access = self.next_storage_access()?; + self.expect(Token::Paren('>'))?; + Ok((format, access)) + } + + pub(super) fn open_arguments(&mut self) -> Result<(), Error<'a>> { + self.expect(Token::Paren('(')) + } + + pub(super) fn close_arguments(&mut self) -> Result<(), Error<'a>> { + let _ = self.skip(Token::Separator(',')); + self.expect(Token::Paren(')')) + } + + pub(super) fn next_argument(&mut self) -> Result> { + let paren = Token::Paren(')'); + if self.skip(Token::Separator(',')) { + Ok(!self.skip(paren)) + } else { + self.expect(paren).map(|()| false) + } + } +} + +#[cfg(test)] +use super::{number::Number, NumberError}; + +#[cfg(test)] +fn sub_test(source: &str, expected_tokens: &[Token]) { + let mut lex = Lexer::new(source); + for &token in expected_tokens { + assert_eq!(lex.next().0, token); + } + assert_eq!(lex.next().0, Token::End); +} + +#[test] +fn test_numbers() { + // WGSL spec examples // + + // decimal integer + sub_test( + "0x123 0X123u 1u 123 0 0i 0x3f", + &[ + Token::Number(Ok(Number::I32(291))), + Token::Number(Ok(Number::U32(291))), + Token::Number(Ok(Number::U32(1))), + Token::Number(Ok(Number::I32(123))), + Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::I32(63))), + ], + ); + // decimal floating point + sub_test( + "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h", + &[ + Token::Number(Ok(Number::F32(0.))), + Token::Number(Ok(Number::F32(1.))), + Token::Number(Ok(Number::F32(0.01))), + Token::Number(Ok(Number::F32(12.34))), + Token::Number(Ok(Number::F32(0.))), + Token::Number(Err(NumberError::UnimplementedF16)), + Token::Number(Ok(Number::F32(0.001))), + Token::Number(Ok(Number::F32(43.75))), + Token::Number(Ok(Number::F32(16.))), + Token::Number(Ok(Number::F32(0.1875))), + Token::Number(Err(NumberError::UnimplementedF16)), + Token::Number(Ok(Number::F32(0.12109375))), + Token::Number(Err(NumberError::UnimplementedF16)), + ], + ); + + // MIN / MAX // + + // min / max decimal signed integer + sub_test( + "-2147483648i 2147483647i -2147483649i 2147483648i", + &[ + Token::Number(Ok(Number::I32(i32::MIN))), + Token::Number(Ok(Number::I32(i32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + // min / max decimal unsigned integer + sub_test( + "0u 4294967295u -1u 4294967296u", + &[ + Token::Number(Ok(Number::U32(u32::MIN))), + Token::Number(Ok(Number::U32(u32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + // min / max hexadecimal signed integer + sub_test( + "-0x80000000i 0x7FFFFFFFi -0x80000001i 0x80000000i", + &[ + Token::Number(Ok(Number::I32(i32::MIN))), + Token::Number(Ok(Number::I32(i32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + // min / max hexadecimal unsigned integer + sub_test( + "0x0u 0xFFFFFFFFu -0x1u 0x100000000u", + &[ + Token::Number(Ok(Number::U32(u32::MIN))), + Token::Number(Ok(Number::U32(u32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + /// ≈ 2^-126 * 2^−23 (= 2^−149) + const SMALLEST_POSITIVE_SUBNORMAL_F32: f32 = 1e-45; + /// ≈ 2^-126 * (1 − 2^−23) + const LARGEST_SUBNORMAL_F32: f32 = 1.1754942e-38; + /// ≈ 2^-126 + const SMALLEST_POSITIVE_NORMAL_F32: f32 = f32::MIN_POSITIVE; + /// ≈ 1 − 2^−24 + const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994; + /// ≈ 1 + 2^−23 + const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001; + /// ≈ -(2^127 * (2 − 2^−23)) + const SMALLEST_NORMAL_F32: f32 = f32::MIN; + /// ≈ 2^127 * (2 − 2^−23) + const LARGEST_NORMAL_F32: f32 = f32::MAX; + + // decimal floating point + sub_test( + "1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f -3.40282347e+38f 3.40282347e+38f", + &[ + Token::Number(Ok(Number::F32( + SMALLEST_POSITIVE_SUBNORMAL_F32, + ))), + Token::Number(Ok(Number::F32( + LARGEST_SUBNORMAL_F32, + ))), + Token::Number(Ok(Number::F32( + SMALLEST_POSITIVE_NORMAL_F32, + ))), + Token::Number(Ok(Number::F32( + LARGEST_F32_LESS_THAN_ONE, + ))), + Token::Number(Ok(Number::F32( + SMALLEST_F32_LARGER_THAN_ONE, + ))), + Token::Number(Ok(Number::F32( + SMALLEST_NORMAL_F32, + ))), + Token::Number(Ok(Number::F32( + LARGEST_NORMAL_F32, + ))), + ], + ); + sub_test( + "-3.40282367e+38f 3.40282367e+38f", + &[ + Token::Number(Err(NumberError::NotRepresentable)), // ≈ -2^128 + Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128 + ], + ); + + // hexadecimal floating point + sub_test( + "0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f -0xFFFFFFp+104f 0xFFFFFFp+104f", + &[ + Token::Number(Ok(Number::F32( + SMALLEST_POSITIVE_SUBNORMAL_F32, + ))), + Token::Number(Ok(Number::F32( + LARGEST_SUBNORMAL_F32, + ))), + Token::Number(Ok(Number::F32( + SMALLEST_POSITIVE_NORMAL_F32, + ))), + Token::Number(Ok(Number::F32( + LARGEST_F32_LESS_THAN_ONE, + ))), + Token::Number(Ok(Number::F32( + SMALLEST_F32_LARGER_THAN_ONE, + ))), + Token::Number(Ok(Number::F32( + SMALLEST_NORMAL_F32, + ))), + Token::Number(Ok(Number::F32( + LARGEST_NORMAL_F32, + ))), + ], + ); + sub_test( + "-0x1p128f 0x1p128f 0x1.000001p0f", + &[ + Token::Number(Err(NumberError::NotRepresentable)), // = -2^128 + Token::Number(Err(NumberError::NotRepresentable)), // = 2^128 + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); +} + +#[test] +fn test_tokens() { + sub_test("id123_OK", &[Token::Word("id123_OK")]); + sub_test( + "92No", + &[Token::Number(Ok(Number::I32(92))), Token::Word("No")], + ); + sub_test( + "2u3o", + &[ + Token::Number(Ok(Number::U32(2))), + Token::Number(Ok(Number::I32(3))), + Token::Word("o"), + ], + ); + sub_test( + "2.4f44po", + &[ + Token::Number(Ok(Number::F32(2.4))), + Token::Number(Ok(Number::I32(44))), + Token::Word("po"), + ], + ); + sub_test( + "Δέλτα réflexion Кызыл 𐰓𐰏𐰇 朝焼け سلام 검정 שָׁלוֹם गुलाबी փիրուզ", + &[ + Token::Word("Δέλτα"), + Token::Word("réflexion"), + Token::Word("Кызыл"), + Token::Word("𐰓𐰏𐰇"), + Token::Word("朝焼け"), + Token::Word("سلام"), + Token::Word("검정"), + Token::Word("שָׁלוֹם"), + Token::Word("गुलाबी"), + Token::Word("փիրուզ"), + ], + ); + sub_test("æNoø", &[Token::Word("æNoø")]); + sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]); + sub_test("No好", &[Token::Word("No好")]); + sub_test("_No", &[Token::Word("_No")]); + sub_test( + "*/*/***/*//=/*****//", + &[ + Token::Operation('*'), + Token::AssignmentOperation('/'), + Token::Operation('/'), + ], + ); +} + +#[test] +fn test_variable_decl() { + sub_test( + "@group(0 ) var< uniform> texture: texture_multisampled_2d ;", + &[ + Token::Attribute, + Token::Word("group"), + Token::Paren('('), + Token::Number(Ok(Number::I32(0))), + Token::Paren(')'), + Token::Word("var"), + Token::Paren('<'), + Token::Word("uniform"), + Token::Paren('>'), + Token::Word("texture"), + Token::Separator(':'), + Token::Word("texture_multisampled_2d"), + Token::Paren('<'), + Token::Word("f32"), + Token::Paren('>'), + Token::Separator(';'), + ], + ); + sub_test( + "var buffer: array;", + &[ + Token::Word("var"), + Token::Paren('<'), + Token::Word("storage"), + Token::Separator(','), + Token::Word("read_write"), + Token::Paren('>'), + Token::Word("buffer"), + Token::Separator(':'), + Token::Word("array"), + Token::Paren('<'), + Token::Word("u32"), + Token::Paren('>'), + Token::Separator(';'), + ], + ); +} diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs new file mode 100644 index 0000000000..2873e6c73c --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/mod.rs @@ -0,0 +1,4750 @@ +/*! +Frontend for [WGSL][wgsl] (WebGPU Shading Language). + +[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html +*/ + +mod construction; +mod conv; +mod lexer; +mod number; +#[cfg(test)] +mod tests; + +use crate::{ + arena::{Arena, Handle, UniqueArena}, + proc::{ + ensure_block_returns, Alignment, Layouter, ResolveContext, ResolveError, TypeResolution, + }, + span::SourceLocation, + span::Span as NagaSpan, + ConstantInner, FastHashMap, ScalarValue, +}; + +use self::{lexer::Lexer, number::Number}; +use codespan_reporting::{ + diagnostic::{Diagnostic, Label}, + files::SimpleFile, + term::{ + self, + termcolor::{ColorChoice, NoColor, StandardStream}, + }, +}; +use std::{borrow::Cow, convert::TryFrom, ops}; +use thiserror::Error; + +type Span = ops::Range; +type TokenSpan<'a> = (Token<'a>, Span); + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Token<'a> { + Separator(char), + Paren(char), + Attribute, + Number(Result), + Word(&'a str), + Operation(char), + LogicalOperation(char), + ShiftOperation(char), + AssignmentOperation(char), + IncrementOperation, + DecrementOperation, + Arrow, + Unknown(char), + Trivia, + End, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum NumberType { + I32, + U32, + F32, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum ExpectedToken<'a> { + Token(Token<'a>), + Identifier, + Number(NumberType), + Integer, + Constant, + /// Expected: constant, parenthesized expression, identifier + PrimaryExpression, + /// Expected: assignment, increment/decrement expression + Assignment, + /// Expected: '}', identifier + FieldName, + /// Expected: attribute for a type + TypeAttribute, + /// Expected: ';', '{', word + Statement, + /// Expected: 'case', 'default', '}' + SwitchItem, + /// Expected: ',', ')' + WorkgroupSizeSeparator, + /// Expected: 'struct', 'let', 'var', 'type', ';', 'fn', eof + GlobalItem, +} + +#[derive(Clone, Copy, Debug, Error, PartialEq)] +pub enum NumberError { + #[error("invalid numeric literal format")] + Invalid, + #[error("numeric literal not representable by target type")] + NotRepresentable, + #[error("unimplemented f16 type")] + UnimplementedF16, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum InvalidAssignmentType { + Other, + Swizzle, + ImmutableBinding, +} + +#[derive(Clone, Debug)] +pub enum Error<'a> { + Unexpected(Span, ExpectedToken<'a>), + UnexpectedComponents(Span), + BadNumber(Span, NumberError), + /// A negative signed integer literal where both signed and unsigned, + /// but only non-negative literals are allowed. + NegativeInt(Span), + BadU32Constant(Span), + BadMatrixScalarKind(Span, crate::ScalarKind, u8), + BadAccessor(Span), + BadTexture(Span), + BadTypeCast { + span: Span, + from_type: String, + to_type: String, + }, + BadTextureSampleType { + span: Span, + kind: crate::ScalarKind, + width: u8, + }, + BadIncrDecrReferenceType(Span), + InvalidResolve(ResolveError), + InvalidForInitializer(Span), + /// A break if appeared outside of a continuing block + InvalidBreakIf(Span), + InvalidGatherComponent(Span, u32), + InvalidConstructorComponentType(Span, i32), + InvalidIdentifierUnderscore(Span), + ReservedIdentifierPrefix(Span), + UnknownAddressSpace(Span), + UnknownAttribute(Span), + UnknownBuiltin(Span), + UnknownAccess(Span), + UnknownShaderStage(Span), + UnknownIdent(Span, &'a str), + UnknownScalarType(Span), + UnknownType(Span), + UnknownStorageFormat(Span), + UnknownConservativeDepth(Span), + SizeAttributeTooLow(Span, u32), + AlignAttributeTooLow(Span, Alignment), + NonPowerOfTwoAlignAttribute(Span), + InconsistentBinding(Span), + UnknownLocalFunction(Span), + TypeNotConstructible(Span), + TypeNotInferrable(Span), + InitializationTypeMismatch(Span, String), + MissingType(Span), + MissingAttribute(&'static str, Span), + InvalidAtomicPointer(Span), + InvalidAtomicOperandType(Span), + Pointer(&'static str, Span), + NotPointer(Span), + NotReference(&'static str, Span), + InvalidAssignment { + span: Span, + ty: InvalidAssignmentType, + }, + ReservedKeyword(Span), + Redefinition { + previous: Span, + current: Span, + }, + Other, +} + +impl<'a> Error<'a> { + fn as_parse_error(&self, source: &'a str) -> ParseError { + match *self { + Error::Unexpected(ref unexpected_span, expected) => { + let expected_str = match expected { + ExpectedToken::Token(token) => { + match token { + Token::Separator(c) => format!("'{}'", c), + Token::Paren(c) => format!("'{}'", c), + Token::Attribute => "@".to_string(), + Token::Number(_) => "number".to_string(), + Token::Word(s) => s.to_string(), + Token::Operation(c) => format!("operation ('{}')", c), + Token::LogicalOperation(c) => format!("logical operation ('{}')", c), + Token::ShiftOperation(c) => format!("bitshift ('{}{}')", c, c), + Token::AssignmentOperation(c) if c=='<' || c=='>' => format!("bitshift ('{}{}=')", c, c), + Token::AssignmentOperation(c) => format!("operation ('{}=')", c), + Token::IncrementOperation => "increment operation".to_string(), + Token::DecrementOperation => "decrement operation".to_string(), + Token::Arrow => "->".to_string(), + Token::Unknown(c) => format!("unknown ('{}')", c), + Token::Trivia => "trivia".to_string(), + Token::End => "end".to_string(), + } + } + ExpectedToken::Identifier => "identifier".to_string(), + ExpectedToken::Number(ty) => { + match ty { + NumberType::I32 => "32-bit signed integer literal", + NumberType::U32 => "32-bit unsigned integer literal", + NumberType::F32 => "32-bit floating-point literal", + }.to_string() + }, + ExpectedToken::Integer => "unsigned/signed integer literal".to_string(), + ExpectedToken::Constant => "constant".to_string(), + ExpectedToken::PrimaryExpression => "expression".to_string(), + ExpectedToken::Assignment => "assignment or increment/decrement".to_string(), + ExpectedToken::FieldName => "field name or a closing curly bracket to signify the end of the struct".to_string(), + ExpectedToken::TypeAttribute => "type attribute".to_string(), + ExpectedToken::Statement => "statement".to_string(), + ExpectedToken::SwitchItem => "switch item ('case' or 'default') or a closing curly bracket to signify the end of the switch statement ('}')".to_string(), + ExpectedToken::WorkgroupSizeSeparator => "workgroup size separator (',') or a closing parenthesis".to_string(), + ExpectedToken::GlobalItem => "global item ('struct', 'let', 'var', 'type', ';', 'fn') or the end of the file".to_string(), + }; + ParseError { + message: format!( + "expected {}, found '{}'", + expected_str, + &source[unexpected_span.clone()], + ), + labels: vec![( + unexpected_span.clone(), + format!("expected {}", expected_str).into(), + )], + notes: vec![], + } + } + Error::UnexpectedComponents(ref bad_span) => ParseError { + message: "unexpected components".to_string(), + labels: vec![(bad_span.clone(), "unexpected components".into())], + notes: vec![], + }, + Error::BadNumber(ref bad_span, ref err) => ParseError { + message: format!("{}: `{}`", err, &source[bad_span.clone()],), + labels: vec![(bad_span.clone(), err.to_string().into())], + notes: vec![], + }, + Error::NegativeInt(ref bad_span) => ParseError { + message: format!( + "expected non-negative integer literal, found `{}`", + &source[bad_span.clone()], + ), + labels: vec![(bad_span.clone(), "expected non-negative integer".into())], + notes: vec![], + }, + Error::BadU32Constant(ref bad_span) => ParseError { + message: format!( + "expected unsigned integer constant expression, found `{}`", + &source[bad_span.clone()], + ), + labels: vec![(bad_span.clone(), "expected unsigned integer".into())], + notes: vec![], + }, + Error::BadMatrixScalarKind(ref span, kind, width) => ParseError { + message: format!( + "matrix scalar type must be floating-point, but found `{}`", + kind.to_wgsl(width) + ), + labels: vec![(span.clone(), "must be floating-point (e.g. `f32`)".into())], + notes: vec![], + }, + Error::BadAccessor(ref accessor_span) => ParseError { + message: format!( + "invalid field accessor `{}`", + &source[accessor_span.clone()], + ), + labels: vec![(accessor_span.clone(), "invalid accessor".into())], + notes: vec![], + }, + Error::UnknownIdent(ref ident_span, ident) => ParseError { + message: format!("no definition in scope for identifier: '{}'", ident), + labels: vec![(ident_span.clone(), "unknown identifier".into())], + notes: vec![], + }, + Error::UnknownScalarType(ref bad_span) => ParseError { + message: format!("unknown scalar type: '{}'", &source[bad_span.clone()]), + labels: vec![(bad_span.clone(), "unknown scalar type".into())], + notes: vec!["Valid scalar types are f16, f32, f64, \ + i8, i16, i32, i64, \ + u8, u16, u32, u64, bool" + .into()], + }, + Error::BadTextureSampleType { + ref span, + kind, + width, + } => ParseError { + message: format!( + "texture sample type must be one of f32, i32 or u32, but found {}", + kind.to_wgsl(width) + ), + labels: vec![(span.clone(), "must be one of f32, i32 or u32".into())], + notes: vec![], + }, + Error::BadIncrDecrReferenceType(ref span) => ParseError { + message: + "increment/decrement operation requires reference type to be one of i32 or u32" + .to_string(), + labels: vec![( + span.clone(), + "must be a reference type of i32 or u32".into(), + )], + notes: vec![], + }, + Error::BadTexture(ref bad_span) => ParseError { + message: format!( + "expected an image, but found '{}' which is not an image", + &source[bad_span.clone()] + ), + labels: vec![(bad_span.clone(), "not an image".into())], + notes: vec![], + }, + Error::BadTypeCast { + ref span, + ref from_type, + ref to_type, + } => { + let msg = format!("cannot cast a {} to a {}", from_type, to_type); + ParseError { + message: msg.clone(), + labels: vec![(span.clone(), msg.into())], + notes: vec![], + } + } + Error::InvalidResolve(ref resolve_error) => ParseError { + message: resolve_error.to_string(), + labels: vec![], + notes: vec![], + }, + Error::InvalidForInitializer(ref bad_span) => ParseError { + message: format!( + "for(;;) initializer is not an assignment or a function call: '{}'", + &source[bad_span.clone()] + ), + labels: vec![( + bad_span.clone(), + "not an assignment or function call".into(), + )], + notes: vec![], + }, + Error::InvalidBreakIf(ref bad_span) => ParseError { + message: "A break if is only allowed in a continuing block".to_string(), + labels: vec![(bad_span.clone(), "not in a continuing block".into())], + notes: vec![], + }, + Error::InvalidGatherComponent(ref bad_span, component) => ParseError { + message: format!( + "textureGather component {} doesn't exist, must be 0, 1, 2, or 3", + component + ), + labels: vec![(bad_span.clone(), "invalid component".into())], + notes: vec![], + }, + Error::InvalidConstructorComponentType(ref bad_span, component) => ParseError { + message: format!( + "invalid type for constructor component at index [{}]", + component + ), + labels: vec![(bad_span.clone(), "invalid component type".into())], + notes: vec![], + }, + Error::InvalidIdentifierUnderscore(ref bad_span) => ParseError { + message: "Identifier can't be '_'".to_string(), + labels: vec![(bad_span.clone(), "invalid identifier".into())], + notes: vec![ + "Use phony assignment instead ('_ =' notice the absence of 'let' or 'var')" + .to_string(), + ], + }, + Error::ReservedIdentifierPrefix(ref bad_span) => ParseError { + message: format!( + "Identifier starts with a reserved prefix: '{}'", + &source[bad_span.clone()] + ), + labels: vec![(bad_span.clone(), "invalid identifier".into())], + notes: vec![], + }, + Error::UnknownAddressSpace(ref bad_span) => ParseError { + message: format!("unknown address space: '{}'", &source[bad_span.clone()]), + labels: vec![(bad_span.clone(), "unknown address space".into())], + notes: vec![], + }, + Error::UnknownAttribute(ref bad_span) => ParseError { + message: format!("unknown attribute: '{}'", &source[bad_span.clone()]), + labels: vec![(bad_span.clone(), "unknown attribute".into())], + notes: vec![], + }, + Error::UnknownBuiltin(ref bad_span) => ParseError { + message: format!("unknown builtin: '{}'", &source[bad_span.clone()]), + labels: vec![(bad_span.clone(), "unknown builtin".into())], + notes: vec![], + }, + Error::UnknownAccess(ref bad_span) => ParseError { + message: format!("unknown access: '{}'", &source[bad_span.clone()]), + labels: vec![(bad_span.clone(), "unknown access".into())], + notes: vec![], + }, + Error::UnknownShaderStage(ref bad_span) => ParseError { + message: format!("unknown shader stage: '{}'", &source[bad_span.clone()]), + labels: vec![(bad_span.clone(), "unknown shader stage".into())], + notes: vec![], + }, + Error::UnknownStorageFormat(ref bad_span) => ParseError { + message: format!("unknown storage format: '{}'", &source[bad_span.clone()]), + labels: vec![(bad_span.clone(), "unknown storage format".into())], + notes: vec![], + }, + Error::UnknownConservativeDepth(ref bad_span) => ParseError { + message: format!( + "unknown conservative depth: '{}'", + &source[bad_span.clone()] + ), + labels: vec![(bad_span.clone(), "unknown conservative depth".into())], + notes: vec![], + }, + Error::UnknownType(ref bad_span) => ParseError { + message: format!("unknown type: '{}'", &source[bad_span.clone()]), + labels: vec![(bad_span.clone(), "unknown type".into())], + notes: vec![], + }, + Error::SizeAttributeTooLow(ref bad_span, min_size) => ParseError { + message: format!("struct member size must be at least {}", min_size), + labels: vec![( + bad_span.clone(), + format!("must be at least {}", min_size).into(), + )], + notes: vec![], + }, + Error::AlignAttributeTooLow(ref bad_span, min_align) => ParseError { + message: format!("struct member alignment must be at least {}", min_align), + labels: vec![( + bad_span.clone(), + format!("must be at least {}", min_align).into(), + )], + notes: vec![], + }, + Error::NonPowerOfTwoAlignAttribute(ref bad_span) => ParseError { + message: "struct member alignment must be a power of 2".to_string(), + labels: vec![(bad_span.clone(), "must be a power of 2".into())], + notes: vec![], + }, + Error::InconsistentBinding(ref span) => ParseError { + message: "input/output binding is not consistent".to_string(), + labels: vec![( + span.clone(), + "input/output binding is not consistent".into(), + )], + notes: vec![], + }, + Error::UnknownLocalFunction(ref span) => ParseError { + message: format!("unknown local function `{}`", &source[span.clone()]), + labels: vec![(span.clone(), "unknown local function".into())], + notes: vec![], + }, + Error::TypeNotConstructible(ref span) => ParseError { + message: format!("type `{}` is not constructible", &source[span.clone()]), + labels: vec![(span.clone(), "type is not constructible".into())], + notes: vec![], + }, + Error::TypeNotInferrable(ref span) => ParseError { + message: "type can't be inferred".to_string(), + labels: vec![(span.clone(), "type can't be inferred".into())], + notes: vec![], + }, + Error::InitializationTypeMismatch(ref name_span, ref expected_ty) => ParseError { + message: format!( + "the type of `{}` is expected to be `{}`", + &source[name_span.clone()], + expected_ty + ), + labels: vec![( + name_span.clone(), + format!("definition of `{}`", &source[name_span.clone()]).into(), + )], + notes: vec![], + }, + Error::MissingType(ref name_span) => ParseError { + message: format!("variable `{}` needs a type", &source[name_span.clone()]), + labels: vec![( + name_span.clone(), + format!("definition of `{}`", &source[name_span.clone()]).into(), + )], + notes: vec![], + }, + Error::MissingAttribute(name, ref name_span) => ParseError { + message: format!( + "variable `{}` needs a '{}' attribute", + &source[name_span.clone()], + name + ), + labels: vec![( + name_span.clone(), + format!("definition of `{}`", &source[name_span.clone()]).into(), + )], + notes: vec![], + }, + Error::InvalidAtomicPointer(ref span) => ParseError { + message: "atomic operation is done on a pointer to a non-atomic".to_string(), + labels: vec![(span.clone(), "atomic pointer is invalid".into())], + notes: vec![], + }, + Error::InvalidAtomicOperandType(ref span) => ParseError { + message: "atomic operand type is inconsistent with the operation".to_string(), + labels: vec![(span.clone(), "atomic operand type is invalid".into())], + notes: vec![], + }, + Error::NotPointer(ref span) => ParseError { + message: "the operand of the `*` operator must be a pointer".to_string(), + labels: vec![(span.clone(), "expression is not a pointer".into())], + notes: vec![], + }, + Error::NotReference(what, ref span) => ParseError { + message: format!("{} must be a reference", what), + labels: vec![(span.clone(), "expression is not a reference".into())], + notes: vec![], + }, + Error::InvalidAssignment { ref span, ty } => ParseError { + message: "invalid left-hand side of assignment".into(), + labels: vec![(span.clone(), "cannot assign to this expression".into())], + notes: match ty { + InvalidAssignmentType::Swizzle => vec![ + "WGSL does not support assignments to swizzles".into(), + "consider assigning each component individually".into(), + ], + InvalidAssignmentType::ImmutableBinding => vec![ + format!("'{}' is an immutable binding", &source[span.clone()]), + "consider declaring it with `var` instead of `let`".into(), + ], + InvalidAssignmentType::Other => vec![], + }, + }, + Error::Pointer(what, ref span) => ParseError { + message: format!("{} must not be a pointer", what), + labels: vec![(span.clone(), "expression is a pointer".into())], + notes: vec![], + }, + Error::ReservedKeyword(ref name_span) => ParseError { + message: format!( + "name `{}` is a reserved keyword", + &source[name_span.clone()] + ), + labels: vec![( + name_span.clone(), + format!("definition of `{}`", &source[name_span.clone()]).into(), + )], + notes: vec![], + }, + Error::Redefinition { + ref previous, + ref current, + } => ParseError { + message: format!("redefinition of `{}`", &source[current.clone()]), + labels: vec![ + ( + current.clone(), + format!("redefinition of `{}`", &source[current.clone()]).into(), + ), + ( + previous.clone(), + format!("previous definition of `{}`", &source[previous.clone()]).into(), + ), + ], + notes: vec![], + }, + Error::Other => ParseError { + message: "other error".to_string(), + labels: vec![], + notes: vec![], + }, + } + } +} + +impl crate::StorageFormat { + const fn to_wgsl(self) -> &'static str { + use crate::StorageFormat as Sf; + match self { + Sf::R8Unorm => "r8unorm", + Sf::R8Snorm => "r8snorm", + Sf::R8Uint => "r8uint", + Sf::R8Sint => "r8sint", + Sf::R16Uint => "r16uint", + Sf::R16Sint => "r16sint", + Sf::R16Float => "r16float", + Sf::Rg8Unorm => "rg8unorm", + Sf::Rg8Snorm => "rg8snorm", + Sf::Rg8Uint => "rg8uint", + Sf::Rg8Sint => "rg8sint", + Sf::R32Uint => "r32uint", + Sf::R32Sint => "r32sint", + Sf::R32Float => "r32float", + Sf::Rg16Uint => "rg16uint", + Sf::Rg16Sint => "rg16sint", + Sf::Rg16Float => "rg16float", + Sf::Rgba8Unorm => "rgba8unorm", + Sf::Rgba8Snorm => "rgba8snorm", + Sf::Rgba8Uint => "rgba8uint", + Sf::Rgba8Sint => "rgba8sint", + Sf::Rgb10a2Unorm => "rgb10a2unorm", + Sf::Rg11b10Float => "rg11b10float", + Sf::Rg32Uint => "rg32uint", + Sf::Rg32Sint => "rg32sint", + Sf::Rg32Float => "rg32float", + Sf::Rgba16Uint => "rgba16uint", + Sf::Rgba16Sint => "rgba16sint", + Sf::Rgba16Float => "rgba16float", + Sf::Rgba32Uint => "rgba32uint", + Sf::Rgba32Sint => "rgba32sint", + Sf::Rgba32Float => "rgba32float", + } + } +} + +impl crate::TypeInner { + /// Formats the type as it is written in wgsl. + /// + /// For example `vec3`. + /// + /// Note: The names of a `TypeInner::Struct` is not known. Therefore this method will simply return "struct" for them. + fn to_wgsl( + &self, + types: &UniqueArena, + constants: &Arena, + ) -> String { + use crate::TypeInner as Ti; + + match *self { + Ti::Scalar { kind, width } => kind.to_wgsl(width), + Ti::Vector { size, kind, width } => { + format!("vec{}<{}>", size as u32, kind.to_wgsl(width)) + } + Ti::Matrix { + columns, + rows, + width, + } => { + format!( + "mat{}x{}<{}>", + columns as u32, + rows as u32, + crate::ScalarKind::Float.to_wgsl(width), + ) + } + Ti::Atomic { kind, width } => { + format!("atomic<{}>", kind.to_wgsl(width)) + } + Ti::Pointer { base, .. } => { + let base = &types[base]; + let name = base.name.as_deref().unwrap_or("unknown"); + format!("ptr<{}>", name) + } + Ti::ValuePointer { kind, width, .. } => { + format!("ptr<{}>", kind.to_wgsl(width)) + } + Ti::Array { base, size, .. } => { + let member_type = &types[base]; + let base = member_type.name.as_deref().unwrap_or("unknown"); + match size { + crate::ArraySize::Constant(size) => { + let size = constants[size].name.as_deref().unwrap_or("unknown"); + format!("array<{}, {}>", base, size) + } + crate::ArraySize::Dynamic => format!("array<{}>", base), + } + } + Ti::Struct { .. } => { + // TODO: Actually output the struct? + "struct".to_string() + } + Ti::Image { + dim, + arrayed, + class, + } => { + let dim_suffix = match dim { + crate::ImageDimension::D1 => "_1d", + crate::ImageDimension::D2 => "_2d", + crate::ImageDimension::D3 => "_3d", + crate::ImageDimension::Cube => "_cube", + }; + let array_suffix = if arrayed { "_array" } else { "" }; + + let class_suffix = match class { + crate::ImageClass::Sampled { multi: true, .. } => "_multisampled", + crate::ImageClass::Depth { multi: false } => "_depth", + crate::ImageClass::Depth { multi: true } => "_depth_multisampled", + crate::ImageClass::Sampled { multi: false, .. } + | crate::ImageClass::Storage { .. } => "", + }; + + let type_in_brackets = match class { + crate::ImageClass::Sampled { kind, .. } => { + // Note: The only valid widths are 4 bytes wide. + // The lexer has already verified this, so we can safely assume it here. + // https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type + let element_type = kind.to_wgsl(4); + format!("<{}>", element_type) + } + crate::ImageClass::Depth { multi: _ } => String::new(), + crate::ImageClass::Storage { format, access } => { + if access.contains(crate::StorageAccess::STORE) { + format!("<{},write>", format.to_wgsl()) + } else { + format!("<{}>", format.to_wgsl()) + } + } + }; + + format!( + "texture{}{}{}{}", + class_suffix, dim_suffix, array_suffix, type_in_brackets + ) + } + Ti::Sampler { .. } => "sampler".to_string(), + Ti::BindingArray { base, size, .. } => { + let member_type = &types[base]; + let base = member_type.name.as_deref().unwrap_or("unknown"); + match size { + crate::ArraySize::Constant(size) => { + let size = constants[size].name.as_deref().unwrap_or("unknown"); + format!("binding_array<{}, {}>", base, size) + } + crate::ArraySize::Dynamic => format!("binding_array<{}>", base), + } + } + } + } +} + +mod type_inner_tests { + #[test] + fn to_wgsl() { + let mut types = crate::UniqueArena::new(); + let mut constants = crate::Arena::new(); + let c = constants.append( + crate::Constant { + name: Some("C".to_string()), + specialization: None, + inner: crate::ConstantInner::Scalar { + width: 4, + value: crate::ScalarValue::Uint(32), + }, + }, + Default::default(), + ); + + let mytype1 = types.insert( + crate::Type { + name: Some("MyType1".to_string()), + inner: crate::TypeInner::Struct { + members: vec![], + span: 0, + }, + }, + Default::default(), + ); + let mytype2 = types.insert( + crate::Type { + name: Some("MyType2".to_string()), + inner: crate::TypeInner::Struct { + members: vec![], + span: 0, + }, + }, + Default::default(), + ); + + let array = crate::TypeInner::Array { + base: mytype1, + stride: 4, + size: crate::ArraySize::Constant(c), + }; + assert_eq!(array.to_wgsl(&types, &constants), "array"); + + let mat = crate::TypeInner::Matrix { + rows: crate::VectorSize::Quad, + columns: crate::VectorSize::Bi, + width: 8, + }; + assert_eq!(mat.to_wgsl(&types, &constants), "mat2x4"); + + let ptr = crate::TypeInner::Pointer { + base: mytype2, + space: crate::AddressSpace::Storage { + access: crate::StorageAccess::default(), + }, + }; + assert_eq!(ptr.to_wgsl(&types, &constants), "ptr"); + + let img1 = crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: crate::ScalarKind::Float, + multi: true, + }, + }; + assert_eq!( + img1.to_wgsl(&types, &constants), + "texture_multisampled_2d" + ); + + let img2 = crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Depth { multi: false }, + }; + assert_eq!(img2.to_wgsl(&types, &constants), "texture_depth_cube_array"); + + let img3 = crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth { multi: true }, + }; + assert_eq!( + img3.to_wgsl(&types, &constants), + "texture_depth_multisampled_2d" + ); + + let array = crate::TypeInner::BindingArray { + base: mytype1, + size: crate::ArraySize::Constant(c), + }; + assert_eq!( + array.to_wgsl(&types, &constants), + "binding_array" + ); + } +} + +impl crate::ScalarKind { + /// Format a scalar kind+width as a type is written in wgsl. + /// + /// Examples: `f32`, `u64`, `bool`. + fn to_wgsl(self, width: u8) -> String { + let prefix = match self { + crate::ScalarKind::Sint => "i", + crate::ScalarKind::Uint => "u", + crate::ScalarKind::Float => "f", + crate::ScalarKind::Bool => return "bool".to_string(), + }; + format!("{}{}", prefix, width * 8) + } +} + +trait StringValueLookup<'a> { + type Value; + fn lookup(&self, key: &'a str, span: Span) -> Result>; +} +impl<'a> StringValueLookup<'a> for FastHashMap<&'a str, TypedExpression> { + type Value = TypedExpression; + fn lookup(&self, key: &'a str, span: Span) -> Result> { + self.get(key).cloned().ok_or(Error::UnknownIdent(span, key)) + } +} + +struct StatementContext<'input, 'temp, 'out> { + symbol_table: &'temp mut super::SymbolTable<&'input str, TypedExpression>, + typifier: &'temp mut super::Typifier, + variables: &'out mut Arena, + expressions: &'out mut Arena, + named_expressions: &'out mut FastHashMap, String>, + types: &'out mut UniqueArena, + constants: &'out mut Arena, + global_vars: &'out Arena, + functions: &'out Arena, + arguments: &'out [crate::FunctionArgument], +} + +impl<'a, 'temp> StatementContext<'a, 'temp, '_> { + fn reborrow(&mut self) -> StatementContext<'a, '_, '_> { + StatementContext { + symbol_table: self.symbol_table, + typifier: self.typifier, + variables: self.variables, + expressions: self.expressions, + named_expressions: self.named_expressions, + types: self.types, + constants: self.constants, + global_vars: self.global_vars, + functions: self.functions, + arguments: self.arguments, + } + } + + fn as_expression<'t>( + &'t mut self, + block: &'t mut crate::Block, + emitter: &'t mut super::Emitter, + ) -> ExpressionContext<'a, 't, '_> + where + 'temp: 't, + { + ExpressionContext { + symbol_table: self.symbol_table, + typifier: self.typifier, + expressions: self.expressions, + types: self.types, + constants: self.constants, + global_vars: self.global_vars, + local_vars: self.variables, + functions: self.functions, + arguments: self.arguments, + block, + emitter, + } + } +} + +struct SamplingContext { + image: Handle, + arrayed: bool, +} + +struct ExpressionContext<'input, 'temp, 'out> { + symbol_table: &'temp mut super::SymbolTable<&'input str, TypedExpression>, + typifier: &'temp mut super::Typifier, + expressions: &'out mut Arena, + types: &'out mut UniqueArena, + constants: &'out mut Arena, + global_vars: &'out Arena, + local_vars: &'out Arena, + arguments: &'out [crate::FunctionArgument], + functions: &'out Arena, + block: &'temp mut crate::Block, + emitter: &'temp mut super::Emitter, +} + +impl<'a> ExpressionContext<'a, '_, '_> { + fn reborrow(&mut self) -> ExpressionContext<'a, '_, '_> { + ExpressionContext { + symbol_table: self.symbol_table, + typifier: self.typifier, + expressions: self.expressions, + types: self.types, + constants: self.constants, + global_vars: self.global_vars, + local_vars: self.local_vars, + functions: self.functions, + arguments: self.arguments, + block: self.block, + emitter: self.emitter, + } + } + + fn resolve_type( + &mut self, + handle: Handle, + ) -> Result<&crate::TypeInner, Error<'a>> { + let resolve_ctx = ResolveContext { + constants: self.constants, + types: self.types, + global_vars: self.global_vars, + local_vars: self.local_vars, + functions: self.functions, + arguments: self.arguments, + }; + match self.typifier.grow(handle, self.expressions, &resolve_ctx) { + Err(e) => Err(Error::InvalidResolve(e)), + Ok(()) => Ok(self.typifier.get(handle, self.types)), + } + } + + fn prepare_sampling( + &mut self, + image: Handle, + span: Span, + ) -> Result> { + Ok(SamplingContext { + image, + arrayed: match *self.resolve_type(image)? { + crate::TypeInner::Image { arrayed, .. } => arrayed, + _ => return Err(Error::BadTexture(span)), + }, + }) + } + + fn parse_binary_op( + &mut self, + lexer: &mut Lexer<'a>, + classifier: impl Fn(Token<'a>) -> Option, + mut parser: impl FnMut( + &mut Lexer<'a>, + ExpressionContext<'a, '_, '_>, + ) -> Result>, + ) -> Result> { + let start = lexer.start_byte_offset() as u32; + let mut accumulator = parser(lexer, self.reborrow())?; + while let Some(op) = classifier(lexer.peek().0) { + let _ = lexer.next(); + // Binary expressions always apply the load rule to their operands. + let mut left = self.apply_load_rule(accumulator); + let unloaded_right = parser(lexer, self.reborrow())?; + let right = self.apply_load_rule(unloaded_right); + let end = lexer.end_byte_offset() as u32; + left = self.expressions.append( + crate::Expression::Binary { op, left, right }, + NagaSpan::new(start, end), + ); + // Binary expressions never produce references. + accumulator = TypedExpression::non_reference(left); + } + Ok(accumulator) + } + + fn parse_binary_splat_op( + &mut self, + lexer: &mut Lexer<'a>, + classifier: impl Fn(Token<'a>) -> Option, + mut parser: impl FnMut( + &mut Lexer<'a>, + ExpressionContext<'a, '_, '_>, + ) -> Result>, + ) -> Result> { + let start = lexer.start_byte_offset() as u32; + let mut accumulator = parser(lexer, self.reborrow())?; + while let Some(op) = classifier(lexer.peek().0) { + let _ = lexer.next(); + // Binary expressions always apply the load rule to their operands. + let mut left = self.apply_load_rule(accumulator); + let unloaded_right = parser(lexer, self.reborrow())?; + let mut right = self.apply_load_rule(unloaded_right); + let end = lexer.end_byte_offset() as u32; + + self.binary_op_splat(op, &mut left, &mut right)?; + + accumulator = TypedExpression::non_reference(self.expressions.append( + crate::Expression::Binary { op, left, right }, + NagaSpan::new(start, end), + )); + } + Ok(accumulator) + } + + /// Insert splats, if needed by the non-'*' operations. + fn binary_op_splat( + &mut self, + op: crate::BinaryOperator, + left: &mut Handle, + right: &mut Handle, + ) -> Result<(), Error<'a>> { + if op != crate::BinaryOperator::Multiply { + let left_size = match *self.resolve_type(*left)? { + crate::TypeInner::Vector { size, .. } => Some(size), + _ => None, + }; + match (left_size, self.resolve_type(*right)?) { + (Some(size), &crate::TypeInner::Scalar { .. }) => { + *right = self.expressions.append( + crate::Expression::Splat { + size, + value: *right, + }, + self.expressions.get_span(*right), + ); + } + (None, &crate::TypeInner::Vector { size, .. }) => { + *left = self.expressions.append( + crate::Expression::Splat { size, value: *left }, + self.expressions.get_span(*left), + ); + } + _ => {} + } + } + + Ok(()) + } + + /// Add a single expression to the expression table that is not covered by `self.emitter`. + /// + /// This is useful for `CallResult` and `AtomicResult` expressions, which should not be covered by + /// `Emit` statements. + fn interrupt_emitter( + &mut self, + expression: crate::Expression, + span: NagaSpan, + ) -> Handle { + self.block.extend(self.emitter.finish(self.expressions)); + let result = self.expressions.append(expression, span); + self.emitter.start(self.expressions); + result + } + + /// Apply the WGSL Load Rule to `expr`. + /// + /// If `expr` is has type `ref`, perform a load to produce a value of type + /// `T`. Otherwise, return `expr` unchanged. + fn apply_load_rule(&mut self, expr: TypedExpression) -> Handle { + if expr.is_reference { + let load = crate::Expression::Load { + pointer: expr.handle, + }; + let span = self.expressions.get_span(expr.handle); + self.expressions.append(load, span) + } else { + expr.handle + } + } + + /// Creates a zero value constant of type `ty` + /// + /// Returns `None` if the given `ty` is not a constructible type + fn create_zero_value_constant( + &mut self, + ty: Handle, + ) -> Option> { + let inner = match self.types[ty].inner { + crate::TypeInner::Scalar { kind, width } => { + let value = match kind { + crate::ScalarKind::Sint => crate::ScalarValue::Sint(0), + crate::ScalarKind::Uint => crate::ScalarValue::Uint(0), + crate::ScalarKind::Float => crate::ScalarValue::Float(0.), + crate::ScalarKind::Bool => crate::ScalarValue::Bool(false), + }; + crate::ConstantInner::Scalar { width, value } + } + crate::TypeInner::Vector { size, kind, width } => { + let scalar_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { width, kind }, + }, + Default::default(), + ); + let component = self.create_zero_value_constant(scalar_ty); + crate::ConstantInner::Composite { + ty, + components: (0..size as u8).map(|_| component).collect::>()?, + } + } + crate::TypeInner::Matrix { + columns, + rows, + width, + } => { + let vec_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + width, + kind: crate::ScalarKind::Float, + size: rows, + }, + }, + Default::default(), + ); + let component = self.create_zero_value_constant(vec_ty); + crate::ConstantInner::Composite { + ty, + components: (0..columns as u8) + .map(|_| component) + .collect::>()?, + } + } + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => { + let component = self.create_zero_value_constant(base); + crate::ConstantInner::Composite { + ty, + components: (0..self.constants[size].to_array_length().unwrap()) + .map(|_| component) + .collect::>()?, + } + } + crate::TypeInner::Struct { ref members, .. } => { + let members = members.clone(); + crate::ConstantInner::Composite { + ty, + components: members + .iter() + .map(|member| self.create_zero_value_constant(member.ty)) + .collect::>()?, + } + } + _ => return None, + }; + + let constant = self.constants.fetch_or_append( + crate::Constant { + name: None, + specialization: None, + inner, + }, + crate::Span::default(), + ); + Some(constant) + } +} + +/// A Naga [`Expression`] handle, with WGSL type information. +/// +/// Naga and WGSL types are very close, but Naga lacks WGSL's 'reference' types, +/// which we need to know to apply the Load Rule. This struct carries a Naga +/// `Handle` along with enough information to determine its WGSL type. +/// +/// [`Expression`]: crate::Expression +#[derive(Debug, Copy, Clone)] +struct TypedExpression { + /// The handle of the Naga expression. + handle: Handle, + + /// True if this expression's WGSL type is a reference. + /// + /// When this is true, `handle` must be a pointer. + is_reference: bool, +} + +impl TypedExpression { + const fn non_reference(handle: Handle) -> TypedExpression { + TypedExpression { + handle, + is_reference: false, + } + } +} + +enum Composition { + Single(u32), + Multi(crate::VectorSize, [crate::SwizzleComponent; 4]), +} + +impl Composition { + const fn letter_component(letter: char) -> Option { + use crate::SwizzleComponent as Sc; + match letter { + 'x' | 'r' => Some(Sc::X), + 'y' | 'g' => Some(Sc::Y), + 'z' | 'b' => Some(Sc::Z), + 'w' | 'a' => Some(Sc::W), + _ => None, + } + } + + fn extract_impl(name: &str, name_span: Span) -> Result { + let ch = name + .chars() + .next() + .ok_or_else(|| Error::BadAccessor(name_span.clone()))?; + match Self::letter_component(ch) { + Some(sc) => Ok(sc as u32), + None => Err(Error::BadAccessor(name_span)), + } + } + + fn make(name: &str, name_span: Span) -> Result { + if name.len() > 1 { + let mut components = [crate::SwizzleComponent::X; 4]; + for (comp, ch) in components.iter_mut().zip(name.chars()) { + *comp = Self::letter_component(ch) + .ok_or_else(|| Error::BadAccessor(name_span.clone()))?; + } + + let size = match name.len() { + 2 => crate::VectorSize::Bi, + 3 => crate::VectorSize::Tri, + 4 => crate::VectorSize::Quad, + _ => return Err(Error::BadAccessor(name_span)), + }; + Ok(Composition::Multi(size, components)) + } else { + Self::extract_impl(name, name_span).map(Composition::Single) + } + } +} + +#[derive(Default)] +struct TypeAttributes { + // Although WGSL nas no type attributes at the moment, it had them in the past + // (`[[stride]]`) and may as well acquire some again in the future. + // Therefore, we are leaving the plumbing in for now. +} + +/// Which grammar rule we are in the midst of parsing. +/// +/// This is used for error checking. `Parser` maintains a stack of +/// these and (occasionally) checks that it is being pushed and popped +/// as expected. +#[derive(Clone, Debug, PartialEq)] +enum Rule { + Attribute, + VariableDecl, + TypeDecl, + FunctionDecl, + Block, + Statement, + ConstantExpr, + PrimaryExpr, + SingularExpr, + UnaryExpr, + GeneralExpr, +} + +type LocalFunctionCall = (Handle, Vec>); + +#[derive(Default)] +struct BindingParser { + location: Option, + built_in: Option, + interpolation: Option, + sampling: Option, + invariant: bool, +} + +impl BindingParser { + fn parse<'a>( + &mut self, + lexer: &mut Lexer<'a>, + name: &'a str, + name_span: Span, + ) -> Result<(), Error<'a>> { + match name { + "location" => { + lexer.expect(Token::Paren('('))?; + self.location = Some(Parser::parse_non_negative_i32_literal(lexer)?); + lexer.expect(Token::Paren(')'))?; + } + "builtin" => { + lexer.expect(Token::Paren('('))?; + let (raw, span) = lexer.next_ident_with_span()?; + self.built_in = Some(conv::map_built_in(raw, span)?); + lexer.expect(Token::Paren(')'))?; + } + "interpolate" => { + lexer.expect(Token::Paren('('))?; + let (raw, span) = lexer.next_ident_with_span()?; + self.interpolation = Some(conv::map_interpolation(raw, span)?); + if lexer.skip(Token::Separator(',')) { + let (raw, span) = lexer.next_ident_with_span()?; + self.sampling = Some(conv::map_sampling(raw, span)?); + } + lexer.expect(Token::Paren(')'))?; + } + "invariant" => self.invariant = true, + _ => return Err(Error::UnknownAttribute(name_span)), + } + Ok(()) + } + + const fn finish<'a>(self, span: Span) -> Result, Error<'a>> { + match ( + self.location, + self.built_in, + self.interpolation, + self.sampling, + self.invariant, + ) { + (None, None, None, None, false) => Ok(None), + (Some(location), None, interpolation, sampling, false) => { + // Before handing over the completed `Module`, we call + // `apply_default_interpolation` to ensure that the interpolation and + // sampling have been explicitly specified on all vertex shader output and fragment + // shader input user bindings, so leaving them potentially `None` here is fine. + Ok(Some(crate::Binding::Location { + location, + interpolation, + sampling, + })) + } + (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant) => { + Ok(Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { + invariant, + }))) + } + (None, Some(built_in), None, None, false) => { + Ok(Some(crate::Binding::BuiltIn(built_in))) + } + (_, _, _, _, _) => Err(Error::InconsistentBinding(span)), + } + } +} + +struct ParsedVariable<'a> { + name: &'a str, + name_span: Span, + space: Option, + ty: Handle, + init: Option>, +} + +struct CalledFunction { + result: Option>, +} + +#[derive(Clone, Debug)] +pub struct ParseError { + message: String, + labels: Vec<(Span, Cow<'static, str>)>, + notes: Vec, +} + +impl ParseError { + pub fn labels(&self) -> impl Iterator + ExactSizeIterator + '_ { + self.labels + .iter() + .map(|&(ref span, ref msg)| (span.clone(), msg.as_ref())) + } + + pub fn message(&self) -> &str { + &self.message + } + + fn diagnostic(&self) -> Diagnostic<()> { + let diagnostic = Diagnostic::error() + .with_message(self.message.to_string()) + .with_labels( + self.labels + .iter() + .map(|label| { + Label::primary((), label.0.clone()).with_message(label.1.to_string()) + }) + .collect(), + ) + .with_notes( + self.notes + .iter() + .map(|note| format!("note: {}", note)) + .collect(), + ); + diagnostic + } + + /// Emits a summary of the error to standard error stream. + pub fn emit_to_stderr(&self, source: &str) { + self.emit_to_stderr_with_path(source, "wgsl") + } + + /// Emits a summary of the error to standard error stream. + pub fn emit_to_stderr_with_path(&self, source: &str, path: &str) { + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + let writer = StandardStream::stderr(ColorChoice::Auto); + term::emit(&mut writer.lock(), &config, &files, &self.diagnostic()) + .expect("cannot write error"); + } + + /// Emits a summary of the error to a string. + pub fn emit_to_string(&self, source: &str) -> String { + self.emit_to_string_with_path(source, "wgsl") + } + + /// Emits a summary of the error to a string. + pub fn emit_to_string_with_path(&self, source: &str, path: &str) -> String { + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + let mut writer = NoColor::new(Vec::new()); + term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error"); + String::from_utf8(writer.into_inner()).unwrap() + } + + /// Returns a [`SourceLocation`] for the first label in the error message. + pub fn location(&self, source: &str) -> Option { + self.labels + .get(0) + .map(|label| NagaSpan::new(label.0.start as u32, label.0.end as u32).location(source)) + } +} + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for ParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +pub struct Parser { + rules: Vec<(Rule, usize)>, + module_scope_identifiers: FastHashMap, + lookup_type: FastHashMap>, + layouter: Layouter, +} + +impl Parser { + pub fn new() -> Self { + Parser { + rules: Vec::new(), + module_scope_identifiers: FastHashMap::default(), + lookup_type: FastHashMap::default(), + layouter: Default::default(), + } + } + + fn reset(&mut self) { + self.rules.clear(); + self.module_scope_identifiers.clear(); + self.lookup_type.clear(); + self.layouter.clear(); + } + + fn push_rule_span(&mut self, rule: Rule, lexer: &mut Lexer<'_>) { + self.rules.push((rule, lexer.start_byte_offset())); + } + + fn pop_rule_span(&mut self, lexer: &Lexer<'_>) -> Span { + let (_, initial) = self.rules.pop().unwrap(); + lexer.span_from(initial) + } + + fn peek_rule_span(&mut self, lexer: &Lexer<'_>) -> Span { + let &(_, initial) = self.rules.last().unwrap(); + lexer.span_from(initial) + } + + fn parse_switch_value<'a>(lexer: &mut Lexer<'a>, uint: bool) -> Result> { + let token_span = lexer.next(); + match token_span.0 { + Token::Number(Ok(Number::U32(num))) if uint => Ok(num as i32), + Token::Number(Ok(Number::I32(num))) if !uint => Ok(num), + Token::Number(Err(e)) => Err(Error::BadNumber(token_span.1, e)), + _ => Err(Error::Unexpected(token_span.1, ExpectedToken::Integer)), + } + } + + /// Parse a non-negative signed integer literal. + /// This is for attributes like `size`, `location` and others. + fn parse_non_negative_i32_literal<'a>(lexer: &mut Lexer<'a>) -> Result> { + match lexer.next() { + (Token::Number(Ok(Number::I32(num))), span) => { + u32::try_from(num).map_err(|_| Error::NegativeInt(span)) + } + (Token::Number(Err(e)), span) => Err(Error::BadNumber(span, e)), + other => Err(Error::Unexpected( + other.1, + ExpectedToken::Number(NumberType::I32), + )), + } + } + + /// Parse a non-negative integer literal that may be either signed or unsigned. + /// This is for the `workgroup_size` attribute and array lengths. + /// Note: these values should be no larger than [`i32::MAX`], but this is not checked here. + fn parse_generic_non_negative_int_literal<'a>(lexer: &mut Lexer<'a>) -> Result> { + match lexer.next() { + (Token::Number(Ok(Number::I32(num))), span) => { + u32::try_from(num).map_err(|_| Error::NegativeInt(span)) + } + (Token::Number(Ok(Number::U32(num))), _) => Ok(num), + (Token::Number(Err(e)), span) => Err(Error::BadNumber(span, e)), + other => Err(Error::Unexpected( + other.1, + ExpectedToken::Number(NumberType::I32), + )), + } + } + + fn parse_atomic_pointer<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result, Error<'a>> { + let (pointer, pointer_span) = + lexer.capture_span(|lexer| self.parse_general_expression(lexer, ctx.reborrow()))?; + // Check if the pointer expression is to an atomic. + // The IR uses regular `Expression::Load` and `Statement::Store` for atomic load/stores, + // and it will not catch the use of a non-atomic variable here. + match *ctx.resolve_type(pointer)? { + crate::TypeInner::Pointer { base, .. } => match ctx.types[base].inner { + crate::TypeInner::Atomic { .. } => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to atomic op", other); + Err(Error::InvalidAtomicPointer(pointer_span)) + } + }, + ref other => { + log::error!("Type {:?} passed to atomic op", other); + Err(Error::InvalidAtomicPointer(pointer_span)) + } + } + } + + /// Expects name to be peeked from lexer, does not consume if returns None. + fn parse_local_function_call<'a>( + &mut self, + lexer: &mut Lexer<'a>, + name: &'a str, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result, Error<'a>> { + let fun_handle = match ctx.functions.iter().find(|&(_, fun)| match fun.name { + Some(ref string) => string == name, + None => false, + }) { + Some((fun_handle, _)) => fun_handle, + None => return Ok(None), + }; + + let count = ctx.functions[fun_handle].arguments.len(); + let mut arguments = Vec::with_capacity(count); + let _ = lexer.next(); + lexer.open_arguments()?; + while arguments.len() != count { + if !arguments.is_empty() { + lexer.expect(Token::Separator(','))?; + } + let arg = self.parse_general_expression(lexer, ctx.reborrow())?; + arguments.push(arg); + } + lexer.close_arguments()?; + Ok(Some((fun_handle, arguments))) + } + + fn parse_atomic_helper<'a>( + &mut self, + lexer: &mut Lexer<'a>, + fun: crate::AtomicFunction, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result, Error<'a>> { + lexer.open_arguments()?; + let pointer = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let ctx_span = ctx.reborrow(); + let (value, value_span) = + lexer.capture_span(|lexer| self.parse_general_expression(lexer, ctx_span))?; + lexer.close_arguments()?; + + let expression = match *ctx.resolve_type(value)? { + crate::TypeInner::Scalar { kind, width } => crate::Expression::AtomicResult { + kind, + width, + comparison: false, + }, + _ => return Err(Error::InvalidAtomicOperandType(value_span)), + }; + + let span = NagaSpan::from(value_span); + let result = ctx.interrupt_emitter(expression, span); + ctx.block.push( + crate::Statement::Atomic { + pointer, + fun, + value, + result, + }, + span, + ); + Ok(result) + } + + /// Expects [`Rule::PrimaryExpr`] or [`Rule::SingularExpr`] on top; does not pop it. + /// Expects `word` to be peeked (still in lexer), doesn't consume if returning None. + fn parse_function_call_inner<'a>( + &mut self, + lexer: &mut Lexer<'a>, + name: &'a str, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result, Error<'a>> { + assert!(self.rules.last().is_some()); + let expr = if let Some(fun) = conv::map_relational_fun(name) { + let _ = lexer.next(); + lexer.open_arguments()?; + let argument = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::Relational { fun, argument } + } else if let Some(axis) = conv::map_derivative_axis(name) { + let _ = lexer.next(); + lexer.open_arguments()?; + let expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::Derivative { axis, expr } + } else if let Some(fun) = conv::map_standard_fun(name) { + let _ = lexer.next(); + lexer.open_arguments()?; + let arg_count = fun.argument_count(); + let arg = self.parse_general_expression(lexer, ctx.reborrow())?; + let arg1 = if arg_count > 1 { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + let arg2 = if arg_count > 2 { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + let arg3 = if arg_count > 3 { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } + } else { + match name { + "bitcast" => { + let _ = lexer.next(); + lexer.expect_generic_paren('<')?; + let (ty, type_span) = lexer.capture_span(|lexer| { + self.parse_type_decl(lexer, None, ctx.types, ctx.constants) + })?; + lexer.expect_generic_paren('>')?; + + lexer.open_arguments()?; + let expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + + let kind = match ctx.types[ty].inner { + crate::TypeInner::Scalar { kind, .. } => kind, + crate::TypeInner::Vector { kind, .. } => kind, + _ => { + return Err(Error::BadTypeCast { + from_type: format!("{:?}", ctx.resolve_type(expr)?), + span: type_span, + to_type: format!("{:?}", ctx.types[ty].inner), + }) + } + }; + + crate::Expression::As { + expr, + kind, + convert: None, + } + } + "select" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let reject = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let accept = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let condition = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::Select { + condition, + accept, + reject, + } + } + "arrayLength" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let array = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::ArrayLength(array) + } + // atomics + "atomicLoad" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let pointer = self.parse_atomic_pointer(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::Load { pointer } + } + "atomicAdd" => { + let _ = lexer.next(); + let handle = self.parse_atomic_helper( + lexer, + crate::AtomicFunction::Add, + ctx.reborrow(), + )?; + return Ok(Some(CalledFunction { + result: Some(handle), + })); + } + "atomicSub" => { + let _ = lexer.next(); + let handle = self.parse_atomic_helper( + lexer, + crate::AtomicFunction::Subtract, + ctx.reborrow(), + )?; + return Ok(Some(CalledFunction { + result: Some(handle), + })); + } + "atomicAnd" => { + let _ = lexer.next(); + let handle = self.parse_atomic_helper( + lexer, + crate::AtomicFunction::And, + ctx.reborrow(), + )?; + return Ok(Some(CalledFunction { + result: Some(handle), + })); + } + "atomicOr" => { + let _ = lexer.next(); + let handle = self.parse_atomic_helper( + lexer, + crate::AtomicFunction::InclusiveOr, + ctx.reborrow(), + )?; + return Ok(Some(CalledFunction { + result: Some(handle), + })); + } + "atomicXor" => { + let _ = lexer.next(); + let handle = self.parse_atomic_helper( + lexer, + crate::AtomicFunction::ExclusiveOr, + ctx.reborrow(), + )?; + return Ok(Some(CalledFunction { + result: Some(handle), + })); + } + "atomicMin" => { + let _ = lexer.next(); + let handle = + self.parse_atomic_helper(lexer, crate::AtomicFunction::Min, ctx)?; + return Ok(Some(CalledFunction { + result: Some(handle), + })); + } + "atomicMax" => { + let _ = lexer.next(); + let handle = + self.parse_atomic_helper(lexer, crate::AtomicFunction::Max, ctx)?; + return Ok(Some(CalledFunction { + result: Some(handle), + })); + } + "atomicExchange" => { + let _ = lexer.next(); + let handle = self.parse_atomic_helper( + lexer, + crate::AtomicFunction::Exchange { compare: None }, + ctx, + )?; + return Ok(Some(CalledFunction { + result: Some(handle), + })); + } + "atomicCompareExchangeWeak" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let pointer = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let cmp = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let (value, value_span) = lexer.capture_span(|lexer| { + self.parse_general_expression(lexer, ctx.reborrow()) + })?; + lexer.close_arguments()?; + + let expression = match *ctx.resolve_type(value)? { + crate::TypeInner::Scalar { kind, width } => { + crate::Expression::AtomicResult { + kind, + width, + comparison: true, + } + } + _ => return Err(Error::InvalidAtomicOperandType(value_span)), + }; + + let span = NagaSpan::from(self.peek_rule_span(lexer)); + let result = ctx.interrupt_emitter(expression, span); + ctx.block.push( + crate::Statement::Atomic { + pointer, + fun: crate::AtomicFunction::Exchange { compare: Some(cmp) }, + value, + result, + }, + span, + ); + return Ok(Some(CalledFunction { + result: Some(result), + })); + } + // texture sampling + "textureSample" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let (image, image_span) = + self.parse_general_expression_with_span(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let sampler_expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let coordinate = self.parse_general_expression(lexer, ctx.reborrow())?; + let sc = ctx.prepare_sampling(image, image_span)?; + let array_index = if sc.arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + let offset = if lexer.skip(Token::Separator(',')) { + Some(self.parse_const_expression(lexer, ctx.types, ctx.constants)?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageSample { + image: sc.image, + sampler: sampler_expr, + gather: None, + coordinate, + array_index, + offset, + level: crate::SampleLevel::Auto, + depth_ref: None, + } + } + "textureSampleLevel" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let (image, image_span) = + self.parse_general_expression_with_span(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let sampler_expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let coordinate = self.parse_general_expression(lexer, ctx.reborrow())?; + let sc = ctx.prepare_sampling(image, image_span)?; + let array_index = if sc.arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + lexer.expect(Token::Separator(','))?; + let level = self.parse_general_expression(lexer, ctx.reborrow())?; + let offset = if lexer.skip(Token::Separator(',')) { + Some(self.parse_const_expression(lexer, ctx.types, ctx.constants)?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageSample { + image: sc.image, + sampler: sampler_expr, + gather: None, + coordinate, + array_index, + offset, + level: crate::SampleLevel::Exact(level), + depth_ref: None, + } + } + "textureSampleBias" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let (image, image_span) = + self.parse_general_expression_with_span(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let sampler_expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let coordinate = self.parse_general_expression(lexer, ctx.reborrow())?; + let sc = ctx.prepare_sampling(image, image_span)?; + let array_index = if sc.arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + lexer.expect(Token::Separator(','))?; + let bias = self.parse_general_expression(lexer, ctx.reborrow())?; + let offset = if lexer.skip(Token::Separator(',')) { + Some(self.parse_const_expression(lexer, ctx.types, ctx.constants)?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageSample { + image: sc.image, + sampler: sampler_expr, + gather: None, + coordinate, + array_index, + offset, + level: crate::SampleLevel::Bias(bias), + depth_ref: None, + } + } + "textureSampleGrad" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let (image, image_span) = + self.parse_general_expression_with_span(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let sampler_expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let coordinate = self.parse_general_expression(lexer, ctx.reborrow())?; + let sc = ctx.prepare_sampling(image, image_span)?; + let array_index = if sc.arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + lexer.expect(Token::Separator(','))?; + let x = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let y = self.parse_general_expression(lexer, ctx.reborrow())?; + let offset = if lexer.skip(Token::Separator(',')) { + Some(self.parse_const_expression(lexer, ctx.types, ctx.constants)?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageSample { + image: sc.image, + sampler: sampler_expr, + gather: None, + coordinate, + array_index, + offset, + level: crate::SampleLevel::Gradient { x, y }, + depth_ref: None, + } + } + "textureSampleCompare" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let (image, image_span) = + self.parse_general_expression_with_span(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let sampler_expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let coordinate = self.parse_general_expression(lexer, ctx.reborrow())?; + let sc = ctx.prepare_sampling(image, image_span)?; + let array_index = if sc.arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + lexer.expect(Token::Separator(','))?; + let reference = self.parse_general_expression(lexer, ctx.reborrow())?; + let offset = if lexer.skip(Token::Separator(',')) { + Some(self.parse_const_expression(lexer, ctx.types, ctx.constants)?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageSample { + image: sc.image, + sampler: sampler_expr, + gather: None, + coordinate, + array_index, + offset, + level: crate::SampleLevel::Auto, + depth_ref: Some(reference), + } + } + "textureSampleCompareLevel" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let (image, image_span) = + self.parse_general_expression_with_span(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let sampler_expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let coordinate = self.parse_general_expression(lexer, ctx.reborrow())?; + let sc = ctx.prepare_sampling(image, image_span)?; + let array_index = if sc.arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + lexer.expect(Token::Separator(','))?; + let reference = self.parse_general_expression(lexer, ctx.reborrow())?; + let offset = if lexer.skip(Token::Separator(',')) { + Some(self.parse_const_expression(lexer, ctx.types, ctx.constants)?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageSample { + image: sc.image, + sampler: sampler_expr, + gather: None, + coordinate, + array_index, + offset, + level: crate::SampleLevel::Zero, + depth_ref: Some(reference), + } + } + "textureGather" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let component = if let (Token::Number(..), span) = lexer.peek() { + let index = Self::parse_non_negative_i32_literal(lexer)?; + lexer.expect(Token::Separator(','))?; + *crate::SwizzleComponent::XYZW + .get(index as usize) + .ok_or(Error::InvalidGatherComponent(span, index))? + } else { + crate::SwizzleComponent::X + }; + let (image, image_span) = + self.parse_general_expression_with_span(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let sampler_expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let coordinate = self.parse_general_expression(lexer, ctx.reborrow())?; + let sc = ctx.prepare_sampling(image, image_span)?; + let array_index = if sc.arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + let offset = if lexer.skip(Token::Separator(',')) { + Some(self.parse_const_expression(lexer, ctx.types, ctx.constants)?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageSample { + image: sc.image, + sampler: sampler_expr, + gather: Some(component), + coordinate, + array_index, + offset, + level: crate::SampleLevel::Zero, + depth_ref: None, + } + } + "textureGatherCompare" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let (image, image_span) = + self.parse_general_expression_with_span(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let sampler_expr = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let coordinate = self.parse_general_expression(lexer, ctx.reborrow())?; + let sc = ctx.prepare_sampling(image, image_span)?; + let array_index = if sc.arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + lexer.expect(Token::Separator(','))?; + let reference = self.parse_general_expression(lexer, ctx.reborrow())?; + let offset = if lexer.skip(Token::Separator(',')) { + Some(self.parse_const_expression(lexer, ctx.types, ctx.constants)?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageSample { + image: sc.image, + sampler: sampler_expr, + gather: Some(crate::SwizzleComponent::X), + coordinate, + array_index, + offset, + level: crate::SampleLevel::Zero, + depth_ref: Some(reference), + } + } + "textureLoad" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let (image, image_span) = + self.parse_general_expression_with_span(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let coordinate = self.parse_general_expression(lexer, ctx.reborrow())?; + let (class, arrayed) = match *ctx.resolve_type(image)? { + crate::TypeInner::Image { class, arrayed, .. } => (class, arrayed), + _ => return Err(Error::BadTexture(image_span)), + }; + let array_index = if arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + let level = if class.is_mipmapped() { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + let sample = if class.is_multisampled() { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression(lexer, ctx.reborrow())?) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } + } + "textureDimensions" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let image = self.parse_general_expression(lexer, ctx.reborrow())?; + let level = if lexer.skip(Token::Separator(',')) { + let expr = self.parse_general_expression(lexer, ctx.reborrow())?; + Some(expr) + } else { + None + }; + lexer.close_arguments()?; + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::Size { level }, + } + } + "textureNumLevels" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let image = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumLevels, + } + } + "textureNumLayers" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let image = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumLayers, + } + } + "textureNumSamples" => { + let _ = lexer.next(); + lexer.open_arguments()?; + let image = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumSamples, + } + } + // other + _ => { + let result = + match self.parse_local_function_call(lexer, name, ctx.reborrow())? { + Some((function, arguments)) => { + let span = NagaSpan::from(self.peek_rule_span(lexer)); + ctx.block.extend(ctx.emitter.finish(ctx.expressions)); + let result = ctx.functions[function].result.as_ref().map(|_| { + ctx.expressions + .append(crate::Expression::CallResult(function), span) + }); + ctx.emitter.start(ctx.expressions); + ctx.block.push( + crate::Statement::Call { + function, + arguments, + result, + }, + span, + ); + result + } + None => return Ok(None), + }; + return Ok(Some(CalledFunction { result })); + } + } + }; + let span = NagaSpan::from(self.peek_rule_span(lexer)); + let handle = ctx.expressions.append(expr, span); + Ok(Some(CalledFunction { + result: Some(handle), + })) + } + + fn parse_const_expression_impl<'a>( + &mut self, + first_token_span: TokenSpan<'a>, + lexer: &mut Lexer<'a>, + register_name: Option<&'a str>, + type_arena: &mut UniqueArena, + const_arena: &mut Arena, + ) -> Result, Error<'a>> { + self.push_rule_span(Rule::ConstantExpr, lexer); + let inner = match first_token_span { + (Token::Word("true"), _) => crate::ConstantInner::boolean(true), + (Token::Word("false"), _) => crate::ConstantInner::boolean(false), + (Token::Number(num), _) => match num { + Ok(Number::I32(num)) => crate::ConstantInner::Scalar { + value: crate::ScalarValue::Sint(num as i64), + width: 4, + }, + Ok(Number::U32(num)) => crate::ConstantInner::Scalar { + value: crate::ScalarValue::Uint(num as u64), + width: 4, + }, + Ok(Number::F32(num)) => crate::ConstantInner::Scalar { + value: crate::ScalarValue::Float(num as f64), + width: 4, + }, + Ok(Number::AbstractInt(_) | Number::AbstractFloat(_)) => unreachable!(), + Err(e) => return Err(Error::BadNumber(first_token_span.1, e)), + }, + (Token::Word(name), name_span) => { + // look for an existing constant first + for (handle, var) in const_arena.iter() { + match var.name { + Some(ref string) if string == name => { + self.pop_rule_span(lexer); + return Ok(handle); + } + _ => {} + } + } + let composite_ty = self.parse_type_decl_name( + lexer, + name, + name_span, + None, + TypeAttributes::default(), + type_arena, + const_arena, + )?; + + lexer.open_arguments()?; + //Note: this expects at least one argument + let mut components = Vec::new(); + while components.is_empty() || lexer.next_argument()? { + let component = self.parse_const_expression(lexer, type_arena, const_arena)?; + components.push(component); + } + crate::ConstantInner::Composite { + ty: composite_ty, + components, + } + } + other => return Err(Error::Unexpected(other.1, ExpectedToken::Constant)), + }; + + // Only set span if it's a named constant. Otherwise, the enclosing Expression should have + // the span. + let span = self.pop_rule_span(lexer); + let handle = if let Some(name) = register_name { + if crate::keywords::wgsl::RESERVED.contains(&name) { + return Err(Error::ReservedKeyword(span)); + } + const_arena.append( + crate::Constant { + name: Some(name.to_string()), + specialization: None, + inner, + }, + NagaSpan::from(span), + ) + } else { + const_arena.fetch_or_append( + crate::Constant { + name: None, + specialization: None, + inner, + }, + Default::default(), + ) + }; + + Ok(handle) + } + + fn parse_const_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + type_arena: &mut UniqueArena, + const_arena: &mut Arena, + ) -> Result, Error<'a>> { + self.parse_const_expression_impl(lexer.next(), lexer, None, type_arena, const_arena) + } + + fn parse_primary_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result> { + // Will be popped inside match, possibly inside parse_function_call_inner or parse_construction + self.push_rule_span(Rule::PrimaryExpr, lexer); + let expr = match lexer.peek() { + (Token::Paren('('), _) => { + let _ = lexer.next(); + let (expr, _span) = + self.parse_general_expression_for_reference(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + self.pop_rule_span(lexer); + expr + } + (Token::Word("true" | "false") | Token::Number(..), _) => { + let const_handle = self.parse_const_expression(lexer, ctx.types, ctx.constants)?; + let span = NagaSpan::from(self.pop_rule_span(lexer)); + TypedExpression::non_reference( + ctx.interrupt_emitter(crate::Expression::Constant(const_handle), span), + ) + } + (Token::Word(word), span) => { + if let Some(definition) = ctx.symbol_table.lookup(word) { + let _ = lexer.next(); + self.pop_rule_span(lexer); + + *definition + } else if let Some(CalledFunction { result: Some(expr) }) = + self.parse_function_call_inner(lexer, word, ctx.reborrow())? + { + //TODO: resolve the duplicate call in `parse_singular_expression` + self.pop_rule_span(lexer); + TypedExpression::non_reference(expr) + } else { + let _ = lexer.next(); + if let Some(expr) = construction::parse_construction( + self, + lexer, + word, + span.clone(), + ctx.reborrow(), + )? { + TypedExpression::non_reference(expr) + } else { + return Err(Error::UnknownIdent(span, word)); + } + } + } + other => return Err(Error::Unexpected(other.1, ExpectedToken::PrimaryExpression)), + }; + Ok(expr) + } + + fn parse_postfix<'a>( + &mut self, + span_start: usize, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + expr: TypedExpression, + ) -> Result> { + // Parse postfix expressions, adjusting `handle` and `is_reference` along the way. + // + // Most postfix expressions don't affect `is_reference`: for example, `s.x` is a + // reference whenever `s` is a reference. But swizzles (WGSL spec: "multiple + // component selection") apply the load rule, converting references to values, so + // those affect `is_reference` as well as `handle`. + let TypedExpression { + mut handle, + mut is_reference, + } = expr; + let mut prefix_span = lexer.span_from(span_start); + + loop { + // Step lightly around `resolve_type`'s mutable borrow. + ctx.resolve_type(handle)?; + + // Find the type of the composite whose elements, components or members we're + // accessing, skipping through references: except for swizzles, the `Access` + // or `AccessIndex` expressions we'd generate are the same either way. + // + // Pointers, however, are not permitted. For error checks below, note whether + // the base expression is a WGSL pointer. + let temp_inner; + let (composite, wgsl_pointer) = match *ctx.typifier.get(handle, ctx.types) { + crate::TypeInner::Pointer { base, .. } => (&ctx.types[base].inner, !is_reference), + crate::TypeInner::ValuePointer { + size: None, + kind, + width, + .. + } => { + temp_inner = crate::TypeInner::Scalar { kind, width }; + (&temp_inner, !is_reference) + } + crate::TypeInner::ValuePointer { + size: Some(size), + kind, + width, + .. + } => { + temp_inner = crate::TypeInner::Vector { size, kind, width }; + (&temp_inner, !is_reference) + } + ref other => (other, false), + }; + + let expression = match lexer.peek().0 { + Token::Separator('.') => { + let _ = lexer.next(); + let (name, name_span) = lexer.next_ident_with_span()?; + + // WGSL doesn't allow accessing members on pointers, or swizzling + // them. But Naga IR doesn't distinguish pointers and references, so + // we must check here. + if wgsl_pointer { + return Err(Error::Pointer( + "the value accessed by a `.member` expression", + prefix_span, + )); + } + + let access = match *composite { + crate::TypeInner::Struct { ref members, .. } => { + let index = members + .iter() + .position(|m| m.name.as_deref() == Some(name)) + .ok_or(Error::BadAccessor(name_span))? + as u32; + crate::Expression::AccessIndex { + base: handle, + index, + } + } + crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => { + match Composition::make(name, name_span)? { + Composition::Multi(size, pattern) => { + // Once you apply the load rule, the expression is no + // longer a reference. + let current_expr = TypedExpression { + handle, + is_reference, + }; + let vector = ctx.apply_load_rule(current_expr); + is_reference = false; + + crate::Expression::Swizzle { + size, + vector, + pattern, + } + } + Composition::Single(index) => crate::Expression::AccessIndex { + base: handle, + index, + }, + } + } + _ => return Err(Error::BadAccessor(name_span)), + }; + + access + } + Token::Paren('[') => { + let (_, open_brace_span) = lexer.next(); + let index = self.parse_general_expression(lexer, ctx.reborrow())?; + let close_brace_span = lexer.expect_span(Token::Paren(']'))?; + + // WGSL doesn't allow pointers to be subscripted. But Naga IR doesn't + // distinguish pointers and references, so we must check here. + if wgsl_pointer { + return Err(Error::Pointer( + "the value indexed by a `[]` subscripting expression", + prefix_span, + )); + } + + if let crate::Expression::Constant(constant) = ctx.expressions[index] { + let expr_span = open_brace_span.end..close_brace_span.start; + + let index = match ctx.constants[constant].inner { + ConstantInner::Scalar { + value: ScalarValue::Uint(int), + .. + } => u32::try_from(int).map_err(|_| Error::BadU32Constant(expr_span)), + ConstantInner::Scalar { + value: ScalarValue::Sint(int), + .. + } => u32::try_from(int).map_err(|_| Error::BadU32Constant(expr_span)), + _ => Err(Error::BadU32Constant(expr_span)), + }?; + + crate::Expression::AccessIndex { + base: handle, + index, + } + } else { + crate::Expression::Access { + base: handle, + index, + } + } + } + _ => break, + }; + + prefix_span = lexer.span_from(span_start); + handle = ctx + .expressions + .append(expression, NagaSpan::from(prefix_span.clone())); + } + + Ok(TypedExpression { + handle, + is_reference, + }) + } + + /// Parse a `unary_expression`. + fn parse_unary_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result> { + self.push_rule_span(Rule::UnaryExpr, lexer); + //TODO: refactor this to avoid backing up + let expr = match lexer.peek().0 { + Token::Operation('-') => { + let _ = lexer.next(); + let unloaded_expr = self.parse_unary_expression(lexer, ctx.reborrow())?; + let expr = ctx.apply_load_rule(unloaded_expr); + let expr = crate::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr, + }; + let span = NagaSpan::from(self.peek_rule_span(lexer)); + TypedExpression::non_reference(ctx.expressions.append(expr, span)) + } + Token::Operation('!' | '~') => { + let _ = lexer.next(); + let unloaded_expr = self.parse_unary_expression(lexer, ctx.reborrow())?; + let expr = ctx.apply_load_rule(unloaded_expr); + let expr = crate::Expression::Unary { + op: crate::UnaryOperator::Not, + expr, + }; + let span = NagaSpan::from(self.peek_rule_span(lexer)); + TypedExpression::non_reference(ctx.expressions.append(expr, span)) + } + Token::Operation('*') => { + let _ = lexer.next(); + // The `*` operator does not accept a reference, so we must apply the Load + // Rule here. But the operator itself simply changes the type from + // `ptr` to `ref`, so we generate no code for the + // operator itself. We simply return a `TypedExpression` with + // `is_reference` set to true. + let unloaded_pointer = self.parse_unary_expression(lexer, ctx.reborrow())?; + let pointer = ctx.apply_load_rule(unloaded_pointer); + + // An expression like `&*ptr` may generate no Naga IR at all, but WGSL requires + // an error if `ptr` is not a pointer. So we have to type-check this ourselves. + if ctx.resolve_type(pointer)?.pointer_space().is_none() { + let span = ctx + .expressions + .get_span(pointer) + .to_range() + .unwrap_or_else(|| self.peek_rule_span(lexer)); + return Err(Error::NotPointer(span)); + } + + TypedExpression { + handle: pointer, + is_reference: true, + } + } + Token::Operation('&') => { + let _ = lexer.next(); + // The `&` operator simply converts a reference to a pointer. And since a + // reference is required, the Load Rule is not applied. + let operand = self.parse_unary_expression(lexer, ctx.reborrow())?; + if !operand.is_reference { + let span = ctx + .expressions + .get_span(operand.handle) + .to_range() + .unwrap_or_else(|| self.peek_rule_span(lexer)); + return Err(Error::NotReference("the operand of the `&` operator", span)); + } + + // No code is generated. We just declare the pointer a reference now. + TypedExpression { + is_reference: false, + ..operand + } + } + _ => self.parse_singular_expression(lexer, ctx.reborrow())?, + }; + + self.pop_rule_span(lexer); + Ok(expr) + } + + /// Parse a `singular_expression`. + fn parse_singular_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result> { + let start = lexer.start_byte_offset(); + self.push_rule_span(Rule::SingularExpr, lexer); + let primary_expr = self.parse_primary_expression(lexer, ctx.reborrow())?; + let singular_expr = self.parse_postfix(start, lexer, ctx.reborrow(), primary_expr)?; + self.pop_rule_span(lexer); + + Ok(singular_expr) + } + + fn parse_equality_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: ExpressionContext<'a, '_, '_>, + ) -> Result> { + // equality_expression + context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal), + Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual), + _ => None, + }, + // relational_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Paren('<') => Some(crate::BinaryOperator::Less), + Token::Paren('>') => Some(crate::BinaryOperator::Greater), + Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual), + Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual), + _ => None, + }, + // shift_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::ShiftOperation('<') => { + Some(crate::BinaryOperator::ShiftLeft) + } + Token::ShiftOperation('>') => { + Some(crate::BinaryOperator::ShiftRight) + } + _ => None, + }, + // additive_expression + |lexer, mut context| { + context.parse_binary_splat_op( + lexer, + |token| match token { + Token::Operation('+') => Some(crate::BinaryOperator::Add), + Token::Operation('-') => { + Some(crate::BinaryOperator::Subtract) + } + _ => None, + }, + // multiplicative_expression + |lexer, mut context| { + context.parse_binary_splat_op( + lexer, + |token| match token { + Token::Operation('*') => { + Some(crate::BinaryOperator::Multiply) + } + Token::Operation('/') => { + Some(crate::BinaryOperator::Divide) + } + Token::Operation('%') => { + Some(crate::BinaryOperator::Modulo) + } + _ => None, + }, + |lexer, context| { + self.parse_unary_expression(lexer, context) + }, + ) + }, + ) + }, + ) + }, + ) + }, + ) + } + + fn parse_general_expression_with_span<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result<(Handle, Span), Error<'a>> { + let (expr, span) = self.parse_general_expression_for_reference(lexer, ctx.reborrow())?; + Ok((ctx.apply_load_rule(expr), span)) + } + + fn parse_general_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result, Error<'a>> { + let (expr, _span) = self.parse_general_expression_for_reference(lexer, ctx.reborrow())?; + Ok(ctx.apply_load_rule(expr)) + } + + fn parse_general_expression_for_reference<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: ExpressionContext<'a, '_, '_>, + ) -> Result<(TypedExpression, Span), Error<'a>> { + self.push_rule_span(Rule::GeneralExpr, lexer); + // logical_or_expression + let handle = context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr), + _ => None, + }, + // logical_and_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd), + _ => None, + }, + // inclusive_or_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr), + _ => None, + }, + // exclusive_or_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('^') => { + Some(crate::BinaryOperator::ExclusiveOr) + } + _ => None, + }, + // and_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('&') => { + Some(crate::BinaryOperator::And) + } + _ => None, + }, + |lexer, context| { + self.parse_equality_expression(lexer, context) + }, + ) + }, + ) + }, + ) + }, + ) + }, + )?; + Ok((handle, self.pop_rule_span(lexer))) + } + + fn parse_variable_ident_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + type_arena: &mut UniqueArena, + const_arena: &mut Arena, + ) -> Result<(&'a str, Span, Handle), Error<'a>> { + let (name, name_span) = lexer.next_ident_with_span()?; + lexer.expect(Token::Separator(':'))?; + let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + Ok((name, name_span, ty)) + } + + fn parse_variable_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + type_arena: &mut UniqueArena, + const_arena: &mut Arena, + ) -> Result, Error<'a>> { + self.push_rule_span(Rule::VariableDecl, lexer); + let mut space = None; + + if lexer.skip(Token::Paren('<')) { + let (class_str, span) = lexer.next_ident_with_span()?; + space = Some(match class_str { + "storage" => { + let access = if lexer.skip(Token::Separator(',')) { + lexer.next_storage_access()? + } else { + // defaulting to `read` + crate::StorageAccess::LOAD + }; + crate::AddressSpace::Storage { access } + } + _ => conv::map_address_space(class_str, span)?, + }); + lexer.expect(Token::Paren('>'))?; + } + let name = lexer.next_ident()?; + lexer.expect(Token::Separator(':'))?; + let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + + let init = if lexer.skip(Token::Operation('=')) { + let handle = self.parse_const_expression(lexer, type_arena, const_arena)?; + Some(handle) + } else { + None + }; + lexer.expect(Token::Separator(';'))?; + let name_span = self.pop_rule_span(lexer); + Ok(ParsedVariable { + name, + name_span, + space, + ty, + init, + }) + } + + fn parse_struct_body<'a>( + &mut self, + lexer: &mut Lexer<'a>, + type_arena: &mut UniqueArena, + const_arena: &mut Arena, + ) -> Result<(Vec, u32), Error<'a>> { + let mut offset = 0; + let mut struct_alignment = Alignment::ONE; + let mut members = Vec::new(); + + lexer.expect(Token::Paren('{'))?; + let mut ready = true; + while !lexer.skip(Token::Paren('}')) { + if !ready { + return Err(Error::Unexpected( + lexer.next().1, + ExpectedToken::Token(Token::Separator(',')), + )); + } + let (mut size_attr, mut align_attr) = (None, None); + self.push_rule_span(Rule::Attribute, lexer); + let mut bind_parser = BindingParser::default(); + while lexer.skip(Token::Attribute) { + match lexer.next_ident_with_span()? { + ("size", _) => { + lexer.expect(Token::Paren('('))?; + let (value, span) = + lexer.capture_span(Self::parse_non_negative_i32_literal)?; + lexer.expect(Token::Paren(')'))?; + size_attr = Some((value, span)); + } + ("align", _) => { + lexer.expect(Token::Paren('('))?; + let (value, span) = + lexer.capture_span(Self::parse_non_negative_i32_literal)?; + lexer.expect(Token::Paren(')'))?; + align_attr = Some((value, span)); + } + (word, word_span) => bind_parser.parse(lexer, word, word_span)?, + } + } + + let bind_span = self.pop_rule_span(lexer); + let mut binding = bind_parser.finish(bind_span)?; + + let (name, span) = match lexer.next() { + (Token::Word(word), span) => (word, span), + other => return Err(Error::Unexpected(other.1, ExpectedToken::FieldName)), + }; + if crate::keywords::wgsl::RESERVED.contains(&name) { + return Err(Error::ReservedKeyword(span)); + } + lexer.expect(Token::Separator(':'))?; + let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + ready = lexer.skip(Token::Separator(',')); + + self.layouter.update(type_arena, const_arena).unwrap(); + + let member_min_size = self.layouter[ty].size; + let member_min_alignment = self.layouter[ty].alignment; + + let member_size = if let Some((size, span)) = size_attr { + if size < member_min_size { + return Err(Error::SizeAttributeTooLow(span, member_min_size)); + } else { + size + } + } else { + member_min_size + }; + + let member_alignment = if let Some((align, span)) = align_attr { + if let Some(alignment) = Alignment::new(align) { + if alignment < member_min_alignment { + return Err(Error::AlignAttributeTooLow(span, member_min_alignment)); + } else { + alignment + } + } else { + return Err(Error::NonPowerOfTwoAlignAttribute(span)); + } + } else { + member_min_alignment + }; + + offset = member_alignment.round_up(offset); + struct_alignment = struct_alignment.max(member_alignment); + + if let Some(ref mut binding) = binding { + binding.apply_default_interpolation(&type_arena[ty].inner); + } + + members.push(crate::StructMember { + name: Some(name.to_owned()), + ty, + binding, + offset, + }); + + offset += member_size; + } + + let struct_size = struct_alignment.round_up(offset); + Ok((members, struct_size)) + } + + fn parse_matrix_scalar_type<'a>( + &mut self, + lexer: &mut Lexer<'a>, + columns: crate::VectorSize, + rows: crate::VectorSize, + ) -> Result> { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + match kind { + crate::ScalarKind::Float => Ok(crate::TypeInner::Matrix { + columns, + rows, + width, + }), + _ => Err(Error::BadMatrixScalarKind(span, kind, width)), + } + } + + fn parse_type_decl_impl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + _attribute: TypeAttributes, + word: &'a str, + type_arena: &mut UniqueArena, + const_arena: &mut Arena, + ) -> Result, Error<'a>> { + if let Some((kind, width)) = conv::get_scalar_type(word) { + return Ok(Some(crate::TypeInner::Scalar { kind, width })); + } + + Ok(Some(match word { + "vec2" => { + let (kind, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Vector { + size: crate::VectorSize::Bi, + kind, + width, + } + } + "vec3" => { + let (kind, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + kind, + width, + } + } + "vec4" => { + let (kind, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Vector { + size: crate::VectorSize::Quad, + kind, + width, + } + } + "mat2x2" => { + self.parse_matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Bi)? + } + "mat2x3" => { + self.parse_matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Tri)? + } + "mat2x4" => self.parse_matrix_scalar_type( + lexer, + crate::VectorSize::Bi, + crate::VectorSize::Quad, + )?, + "mat3x2" => { + self.parse_matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Bi)? + } + "mat3x3" => self.parse_matrix_scalar_type( + lexer, + crate::VectorSize::Tri, + crate::VectorSize::Tri, + )?, + "mat3x4" => self.parse_matrix_scalar_type( + lexer, + crate::VectorSize::Tri, + crate::VectorSize::Quad, + )?, + "mat4x2" => self.parse_matrix_scalar_type( + lexer, + crate::VectorSize::Quad, + crate::VectorSize::Bi, + )?, + "mat4x3" => self.parse_matrix_scalar_type( + lexer, + crate::VectorSize::Quad, + crate::VectorSize::Tri, + )?, + "mat4x4" => self.parse_matrix_scalar_type( + lexer, + crate::VectorSize::Quad, + crate::VectorSize::Quad, + )?, + "atomic" => { + let (kind, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Atomic { kind, width } + } + "ptr" => { + lexer.expect_generic_paren('<')?; + let (ident, span) = lexer.next_ident_with_span()?; + let mut space = conv::map_address_space(ident, span)?; + lexer.expect(Token::Separator(','))?; + let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + if let crate::AddressSpace::Storage { ref mut access } = space { + *access = if lexer.skip(Token::Separator(',')) { + lexer.next_storage_access()? + } else { + crate::StorageAccess::LOAD + }; + } + lexer.expect_generic_paren('>')?; + crate::TypeInner::Pointer { base, space } + } + "array" => { + lexer.expect_generic_paren('<')?; + let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + let size = if lexer.skip(Token::Separator(',')) { + let const_handle = + self.parse_const_expression(lexer, type_arena, const_arena)?; + crate::ArraySize::Constant(const_handle) + } else { + crate::ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + let stride = { + self.layouter.update(type_arena, const_arena).unwrap(); + self.layouter[base].to_stride() + }; + crate::TypeInner::Array { base, size, stride } + } + "binding_array" => { + lexer.expect_generic_paren('<')?; + let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + let size = if lexer.skip(Token::Separator(',')) { + let const_handle = + self.parse_const_expression(lexer, type_arena, const_arena)?; + crate::ArraySize::Constant(const_handle) + } else { + crate::ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + crate::TypeInner::BindingArray { base, size } + } + "sampler" => crate::TypeInner::Sampler { comparison: false }, + "sampler_comparison" => crate::TypeInner::Sampler { comparison: true }, + "texture_1d" => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(kind, width, span)?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + "texture_1d_array" => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(kind, width, span)?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: true, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + "texture_2d" => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(kind, width, span)?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + "texture_2d_array" => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(kind, width, span)?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + "texture_3d" => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(kind, width, span)?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D3, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + "texture_cube" => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(kind, width, span)?; + crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + "texture_cube_array" => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(kind, width, span)?; + crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + "texture_multisampled_2d" => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(kind, width, span)?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: true }, + } + } + "texture_multisampled_2d_array" => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(kind, width, span)?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Sampled { kind, multi: true }, + } + } + "texture_depth_2d" => crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_2d_array" => crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_cube" => crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: false, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_cube_array" => crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_multisampled_2d" => crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth { multi: true }, + }, + "texture_storage_1d" => { + let (format, access) = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: false, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_1d_array" => { + let (format, access) = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: true, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_2d" => { + let (format, access) = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_2d_array" => { + let (format, access) = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_3d" => { + let (format, access) = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D3, + arrayed: false, + class: crate::ImageClass::Storage { format, access }, + } + } + _ => return Ok(None), + })) + } + + const fn check_texture_sample_type( + kind: crate::ScalarKind, + width: u8, + span: Span, + ) -> Result<(), Error<'static>> { + use crate::ScalarKind::*; + // Validate according to https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type + match (kind, width) { + (Float | Sint | Uint, 4) => Ok(()), + _ => Err(Error::BadTextureSampleType { span, kind, width }), + } + } + + /// Parse type declaration of a given name and attribute. + #[allow(clippy::too_many_arguments)] + fn parse_type_decl_name<'a>( + &mut self, + lexer: &mut Lexer<'a>, + name: &'a str, + name_span: Span, + debug_name: Option<&'a str>, + attribute: TypeAttributes, + type_arena: &mut UniqueArena, + const_arena: &mut Arena, + ) -> Result, Error<'a>> { + Ok(match self.lookup_type.get(name) { + Some(&handle) => handle, + None => { + match self.parse_type_decl_impl(lexer, attribute, name, type_arena, const_arena)? { + Some(inner) => { + let span = name_span.start..lexer.end_byte_offset(); + type_arena.insert( + crate::Type { + name: debug_name.map(|s| s.to_string()), + inner, + }, + NagaSpan::from(span), + ) + } + None => return Err(Error::UnknownType(name_span)), + } + } + }) + } + + fn parse_type_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + debug_name: Option<&'a str>, + type_arena: &mut UniqueArena, + const_arena: &mut Arena, + ) -> Result, Error<'a>> { + self.push_rule_span(Rule::TypeDecl, lexer); + let attribute = TypeAttributes::default(); + + if lexer.skip(Token::Attribute) { + let other = lexer.next(); + return Err(Error::Unexpected(other.1, ExpectedToken::TypeAttribute)); + } + + let (name, name_span) = lexer.next_ident_with_span()?; + let handle = self.parse_type_decl_name( + lexer, + name, + name_span, + debug_name, + attribute, + type_arena, + const_arena, + )?; + self.pop_rule_span(lexer); + // Only set span if it's the first occurrence of the type. + // Type spans therefore should only be used for errors in type declarations; + // use variable spans/expression spans/etc. otherwise + Ok(handle) + } + + /// Parse an assignment statement (will also parse increment and decrement statements) + fn parse_assignment_statement<'a, 'out>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: StatementContext<'a, '_, 'out>, + block: &mut crate::Block, + emitter: &mut super::Emitter, + ) -> Result<(), Error<'a>> { + use crate::BinaryOperator as Bo; + + let span_start = lexer.start_byte_offset(); + emitter.start(context.expressions); + let (reference, lhs_span) = self + .parse_general_expression_for_reference(lexer, context.as_expression(block, emitter))?; + let op = lexer.next(); + // The left hand side of an assignment must be a reference. + if !matches!( + op.0, + Token::Operation('=') + | Token::AssignmentOperation(_) + | Token::IncrementOperation + | Token::DecrementOperation + ) { + return Err(Error::Unexpected(lhs_span, ExpectedToken::Assignment)); + } else if !reference.is_reference { + let ty = if context.named_expressions.contains_key(&reference.handle) { + InvalidAssignmentType::ImmutableBinding + } else { + match *context.expressions.get_mut(reference.handle) { + crate::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle, + _ => InvalidAssignmentType::Other, + } + }; + + return Err(Error::InvalidAssignment { span: lhs_span, ty }); + } + + let mut context = context.as_expression(block, emitter); + + let value = match op { + (Token::Operation('='), _) => { + self.parse_general_expression(lexer, context.reborrow())? + } + (Token::AssignmentOperation(c), span) => { + let op = match c { + '<' => Bo::ShiftLeft, + '>' => Bo::ShiftRight, + '+' => Bo::Add, + '-' => Bo::Subtract, + '*' => Bo::Multiply, + '/' => Bo::Divide, + '%' => Bo::Modulo, + '&' => Bo::And, + '|' => Bo::InclusiveOr, + '^' => Bo::ExclusiveOr, + //Note: `consume_token` shouldn't produce any other assignment ops + _ => unreachable!(), + }; + let mut left = context.expressions.append( + crate::Expression::Load { + pointer: reference.handle, + }, + lhs_span.into(), + ); + let mut right = self.parse_general_expression(lexer, context.reborrow())?; + + context.binary_op_splat(op, &mut left, &mut right)?; + + context + .expressions + .append(crate::Expression::Binary { op, left, right }, span.into()) + } + token @ (Token::IncrementOperation | Token::DecrementOperation, _) => { + let op = match token.0 { + Token::IncrementOperation => Bo::Add, + Token::DecrementOperation => Bo::Subtract, + _ => unreachable!(), + }; + let op_span = token.1; + + // prepare the typifier, but work around mutable borrowing... + let _ = context.resolve_type(reference.handle)?; + + let ty = context.typifier.get(reference.handle, context.types); + let (kind, width) = match *ty { + crate::TypeInner::ValuePointer { + size: None, + kind, + width, + .. + } => (kind, width), + crate::TypeInner::Pointer { base, .. } => match context.types[base].inner { + crate::TypeInner::Scalar { kind, width } => (kind, width), + _ => return Err(Error::BadIncrDecrReferenceType(lhs_span)), + }, + _ => return Err(Error::BadIncrDecrReferenceType(lhs_span)), + }; + let constant_inner = crate::ConstantInner::Scalar { + width, + value: match kind { + crate::ScalarKind::Sint => crate::ScalarValue::Sint(1), + crate::ScalarKind::Uint => crate::ScalarValue::Uint(1), + _ => return Err(Error::BadIncrDecrReferenceType(lhs_span)), + }, + }; + let constant = context.constants.append( + crate::Constant { + name: None, + specialization: None, + inner: constant_inner, + }, + crate::Span::default(), + ); + + let left = context.expressions.append( + crate::Expression::Load { + pointer: reference.handle, + }, + lhs_span.into(), + ); + let right = context.interrupt_emitter( + crate::Expression::Constant(constant), + crate::Span::default(), + ); + context.expressions.append( + crate::Expression::Binary { op, left, right }, + op_span.into(), + ) + } + other => return Err(Error::Unexpected(other.1, ExpectedToken::SwitchItem)), + }; + + let span_end = lexer.end_byte_offset(); + context + .block + .extend(context.emitter.finish(context.expressions)); + context.block.push( + crate::Statement::Store { + pointer: reference.handle, + value, + }, + NagaSpan::from(span_start..span_end), + ); + Ok(()) + } + + /// Parse a function call statement. + fn parse_function_statement<'a, 'out>( + &mut self, + lexer: &mut Lexer<'a>, + ident: &'a str, + mut context: ExpressionContext<'a, '_, 'out>, + ) -> Result<(), Error<'a>> { + self.push_rule_span(Rule::SingularExpr, lexer); + context.emitter.start(context.expressions); + if self + .parse_function_call_inner(lexer, ident, context.reborrow())? + .is_none() + { + let span = lexer.next().1; + return Err(Error::UnknownLocalFunction(span)); + } + context + .block + .extend(context.emitter.finish(context.expressions)); + self.pop_rule_span(lexer); + + Ok(()) + } + + fn parse_switch_case_body<'a, 'out>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: StatementContext<'a, '_, 'out>, + ) -> Result<(bool, crate::Block), Error<'a>> { + let mut body = crate::Block::new(); + // Push a new lexical scope for the switch case body + context.symbol_table.push_scope(); + + lexer.expect(Token::Paren('{'))?; + let fall_through = loop { + // default statements + if lexer.skip(Token::Word("fallthrough")) { + lexer.expect(Token::Separator(';'))?; + lexer.expect(Token::Paren('}'))?; + break true; + } + if lexer.skip(Token::Paren('}')) { + break false; + } + self.parse_statement(lexer, context.reborrow(), &mut body, false)?; + }; + // Pop the switch case body lexical scope + context.symbol_table.pop_scope(); + + Ok((fall_through, body)) + } + + fn parse_statement<'a, 'out>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: StatementContext<'a, '_, 'out>, + block: &'out mut crate::Block, + is_uniform_control_flow: bool, + ) -> Result<(), Error<'a>> { + self.push_rule_span(Rule::Statement, lexer); + match lexer.peek() { + (Token::Separator(';'), _) => { + let _ = lexer.next(); + self.pop_rule_span(lexer); + return Ok(()); + } + (Token::Paren('{'), _) => { + self.push_rule_span(Rule::Block, lexer); + // Push a new lexical scope for the block statement + context.symbol_table.push_scope(); + + let _ = lexer.next(); + let mut statements = crate::Block::new(); + while !lexer.skip(Token::Paren('}')) { + self.parse_statement( + lexer, + context.reborrow(), + &mut statements, + is_uniform_control_flow, + )?; + } + // Pop the block statement lexical scope + context.symbol_table.pop_scope(); + + self.pop_rule_span(lexer); + let span = NagaSpan::from(self.pop_rule_span(lexer)); + block.push(crate::Statement::Block(statements), span); + return Ok(()); + } + (Token::Word(word), _) => { + let mut emitter = super::Emitter::default(); + let statement = match word { + "_" => { + let _ = lexer.next(); + emitter.start(context.expressions); + lexer.expect(Token::Operation('='))?; + self.parse_general_expression( + lexer, + context.as_expression(block, &mut emitter), + )?; + lexer.expect(Token::Separator(';'))?; + block.extend(emitter.finish(context.expressions)); + None + } + "let" => { + let _ = lexer.next(); + emitter.start(context.expressions); + let (name, name_span) = lexer.next_ident_with_span()?; + if crate::keywords::wgsl::RESERVED.contains(&name) { + return Err(Error::ReservedKeyword(name_span)); + } + let given_ty = if lexer.skip(Token::Separator(':')) { + let ty = self.parse_type_decl( + lexer, + None, + context.types, + context.constants, + )?; + Some(ty) + } else { + None + }; + lexer.expect(Token::Operation('='))?; + let expr_id = self.parse_general_expression( + lexer, + context.as_expression(block, &mut emitter), + )?; + lexer.expect(Token::Separator(';'))?; + if let Some(ty) = given_ty { + // prepare the typifier, but work around mutable borrowing... + let _ = context + .as_expression(block, &mut emitter) + .resolve_type(expr_id)?; + let expr_inner = context.typifier.get(expr_id, context.types); + let given_inner = &context.types[ty].inner; + if !given_inner.equivalent(expr_inner, context.types) { + log::error!( + "Given type {:?} doesn't match expected {:?}", + given_inner, + expr_inner + ); + return Err(Error::InitializationTypeMismatch( + name_span, + expr_inner.to_wgsl(context.types, context.constants), + )); + } + } + block.extend(emitter.finish(context.expressions)); + context.symbol_table.add( + name, + TypedExpression { + handle: expr_id, + is_reference: false, + }, + ); + context + .named_expressions + .insert(expr_id, String::from(name)); + None + } + "var" => { + let _ = lexer.next(); + enum Init { + Empty, + Constant(Handle), + Variable(Handle), + } + + let (name, name_span) = lexer.next_ident_with_span()?; + if crate::keywords::wgsl::RESERVED.contains(&name) { + return Err(Error::ReservedKeyword(name_span)); + } + let given_ty = if lexer.skip(Token::Separator(':')) { + let ty = self.parse_type_decl( + lexer, + None, + context.types, + context.constants, + )?; + Some(ty) + } else { + None + }; + + let (init, ty) = if lexer.skip(Token::Operation('=')) { + emitter.start(context.expressions); + let value = self.parse_general_expression( + lexer, + context.as_expression(block, &mut emitter), + )?; + block.extend(emitter.finish(context.expressions)); + + // prepare the typifier, but work around mutable borrowing... + let _ = context + .as_expression(block, &mut emitter) + .resolve_type(value)?; + + //TODO: share more of this code with `let` arm + let ty = match given_ty { + Some(ty) => { + let expr_inner = context.typifier.get(value, context.types); + let given_inner = &context.types[ty].inner; + if !given_inner.equivalent(expr_inner, context.types) { + log::error!( + "Given type {:?} doesn't match expected {:?}", + given_inner, + expr_inner + ); + return Err(Error::InitializationTypeMismatch( + name_span, + expr_inner.to_wgsl(context.types, context.constants), + )); + } + ty + } + None => { + // register the type, if needed + match context.typifier[value].clone() { + TypeResolution::Handle(ty) => ty, + TypeResolution::Value(inner) => context.types.insert( + crate::Type { name: None, inner }, + Default::default(), + ), + } + } + }; + + let init = match context.expressions[value] { + crate::Expression::Constant(handle) if is_uniform_control_flow => { + Init::Constant(handle) + } + _ => Init::Variable(value), + }; + (init, ty) + } else { + match given_ty { + Some(ty) => (Init::Empty, ty), + None => { + log::error!( + "Variable '{}' without an initializer needs a type", + name + ); + return Err(Error::MissingType(name_span)); + } + } + }; + + lexer.expect(Token::Separator(';'))?; + let var_id = context.variables.append( + crate::LocalVariable { + name: Some(name.to_owned()), + ty, + init: match init { + Init::Constant(value) => Some(value), + _ => None, + }, + }, + NagaSpan::from(name_span), + ); + + // Doesn't make sense to assign a span to cached lookup + let expr_id = context + .expressions + .append(crate::Expression::LocalVariable(var_id), Default::default()); + context.symbol_table.add( + name, + TypedExpression { + handle: expr_id, + is_reference: true, + }, + ); + + if let Init::Variable(value) = init { + Some(crate::Statement::Store { + pointer: expr_id, + value, + }) + } else { + None + } + } + "return" => { + let _ = lexer.next(); + let value = if lexer.peek().0 != Token::Separator(';') { + emitter.start(context.expressions); + let handle = self.parse_general_expression( + lexer, + context.as_expression(block, &mut emitter), + )?; + block.extend(emitter.finish(context.expressions)); + Some(handle) + } else { + None + }; + lexer.expect(Token::Separator(';'))?; + Some(crate::Statement::Return { value }) + } + "if" => { + let _ = lexer.next(); + emitter.start(context.expressions); + let condition = self.parse_general_expression( + lexer, + context.as_expression(block, &mut emitter), + )?; + block.extend(emitter.finish(context.expressions)); + + let accept = self.parse_block(lexer, context.reborrow(), false)?; + + let mut elsif_stack = Vec::new(); + let mut elseif_span_start = lexer.start_byte_offset(); + let mut reject = loop { + if !lexer.skip(Token::Word("else")) { + break crate::Block::new(); + } + + if !lexer.skip(Token::Word("if")) { + // ... else { ... } + break self.parse_block(lexer, context.reborrow(), false)?; + } + + // ... else if (...) { ... } + let mut sub_emitter = super::Emitter::default(); + + sub_emitter.start(context.expressions); + let other_condition = self.parse_general_expression( + lexer, + context.as_expression(block, &mut sub_emitter), + )?; + let other_emit = sub_emitter.finish(context.expressions); + let other_block = self.parse_block(lexer, context.reborrow(), false)?; + elsif_stack.push(( + elseif_span_start, + other_condition, + other_emit, + other_block, + )); + elseif_span_start = lexer.start_byte_offset(); + }; + + let span_end = lexer.end_byte_offset(); + // reverse-fold the else-if blocks + //Note: we may consider uplifting this to the IR + for (other_span_start, other_cond, other_emit, other_block) in + elsif_stack.into_iter().rev() + { + let sub_stmt = crate::Statement::If { + condition: other_cond, + accept: other_block, + reject, + }; + reject = crate::Block::new(); + reject.extend(other_emit); + reject.push(sub_stmt, NagaSpan::from(other_span_start..span_end)) + } + + Some(crate::Statement::If { + condition, + accept, + reject, + }) + } + "switch" => { + let _ = lexer.next(); + emitter.start(context.expressions); + let selector = self.parse_general_expression( + lexer, + context.as_expression(block, &mut emitter), + )?; + let uint = Some(crate::ScalarKind::Uint) + == context + .as_expression(block, &mut emitter) + .resolve_type(selector)? + .scalar_kind(); + block.extend(emitter.finish(context.expressions)); + lexer.expect(Token::Paren('{'))?; + let mut cases = Vec::new(); + + loop { + // cases + default + match lexer.next() { + (Token::Word("case"), _) => { + // parse a list of values + let value = loop { + let value = Self::parse_switch_value(lexer, uint)?; + if lexer.skip(Token::Separator(',')) { + if lexer.skip(Token::Separator(':')) { + break value; + } + } else { + lexer.skip(Token::Separator(':')); + break value; + } + cases.push(crate::SwitchCase { + value: crate::SwitchValue::Integer(value), + body: crate::Block::new(), + fall_through: true, + }); + }; + + let (fall_through, body) = + self.parse_switch_case_body(lexer, context.reborrow())?; + + cases.push(crate::SwitchCase { + value: crate::SwitchValue::Integer(value), + body, + fall_through, + }); + } + (Token::Word("default"), _) => { + lexer.skip(Token::Separator(':')); + let (fall_through, body) = + self.parse_switch_case_body(lexer, context.reborrow())?; + cases.push(crate::SwitchCase { + value: crate::SwitchValue::Default, + body, + fall_through, + }); + } + (Token::Paren('}'), _) => break, + other => { + return Err(Error::Unexpected( + other.1, + ExpectedToken::SwitchItem, + )) + } + } + } + + Some(crate::Statement::Switch { selector, cases }) + } + "loop" => Some(self.parse_loop(lexer, context.reborrow(), &mut emitter)?), + "while" => { + let _ = lexer.next(); + let mut body = crate::Block::new(); + + let (condition, span) = lexer.capture_span(|lexer| { + emitter.start(context.expressions); + let condition = self.parse_general_expression( + lexer, + context.as_expression(&mut body, &mut emitter), + )?; + lexer.expect(Token::Paren('{'))?; + body.extend(emitter.finish(context.expressions)); + Ok(condition) + })?; + let mut reject = crate::Block::new(); + reject.push(crate::Statement::Break, NagaSpan::default()); + body.push( + crate::Statement::If { + condition, + accept: crate::Block::new(), + reject, + }, + NagaSpan::from(span), + ); + // Push a lexical scope for the while loop body + context.symbol_table.push_scope(); + + while !lexer.skip(Token::Paren('}')) { + self.parse_statement(lexer, context.reborrow(), &mut body, false)?; + } + // Pop the while loop body lexical scope + context.symbol_table.pop_scope(); + + Some(crate::Statement::Loop { + body, + continuing: crate::Block::new(), + break_if: None, + }) + } + "for" => { + let _ = lexer.next(); + lexer.expect(Token::Paren('('))?; + // Push a lexical scope for the for loop + context.symbol_table.push_scope(); + + if !lexer.skip(Token::Separator(';')) { + let num_statements = block.len(); + let (_, span) = lexer.capture_span(|lexer| { + self.parse_statement( + lexer, + context.reborrow(), + block, + is_uniform_control_flow, + ) + })?; + + if block.len() != num_statements { + match *block.last().unwrap() { + crate::Statement::Store { .. } + | crate::Statement::Call { .. } => {} + _ => return Err(Error::InvalidForInitializer(span)), + } + } + }; + + let mut body = crate::Block::new(); + if !lexer.skip(Token::Separator(';')) { + let (condition, span) = lexer.capture_span(|lexer| { + emitter.start(context.expressions); + let condition = self.parse_general_expression( + lexer, + context.as_expression(&mut body, &mut emitter), + )?; + lexer.expect(Token::Separator(';'))?; + body.extend(emitter.finish(context.expressions)); + Ok(condition) + })?; + let mut reject = crate::Block::new(); + reject.push(crate::Statement::Break, NagaSpan::default()); + body.push( + crate::Statement::If { + condition, + accept: crate::Block::new(), + reject, + }, + NagaSpan::from(span), + ); + }; + + let mut continuing = crate::Block::new(); + if !lexer.skip(Token::Paren(')')) { + match lexer.peek().0 { + Token::Word(ident) + if context.symbol_table.lookup(ident).is_none() => + { + self.parse_function_statement( + lexer, + ident, + context.as_expression(&mut continuing, &mut emitter), + )? + } + _ => self.parse_assignment_statement( + lexer, + context.reborrow(), + &mut continuing, + &mut emitter, + )?, + } + lexer.expect(Token::Paren(')'))?; + } + lexer.expect(Token::Paren('{'))?; + + while !lexer.skip(Token::Paren('}')) { + self.parse_statement(lexer, context.reborrow(), &mut body, false)?; + } + // Pop the for loop lexical scope + context.symbol_table.pop_scope(); + + Some(crate::Statement::Loop { + body, + continuing, + break_if: None, + }) + } + "break" => { + let (_, mut span) = lexer.next(); + // Check if the next token is an `if`, this indicates + // that the user tried to type out a `break if` which + // is illegal in this position. + let (peeked_token, peeked_span) = lexer.peek(); + if let Token::Word("if") = peeked_token { + span.end = peeked_span.end; + return Err(Error::InvalidBreakIf(span)); + } + Some(crate::Statement::Break) + } + "continue" => { + let _ = lexer.next(); + Some(crate::Statement::Continue) + } + "discard" => { + let _ = lexer.next(); + Some(crate::Statement::Kill) + } + "storageBarrier" => { + let _ = lexer.next(); + lexer.expect(Token::Paren('('))?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Statement::Barrier(crate::Barrier::STORAGE)) + } + "workgroupBarrier" => { + let _ = lexer.next(); + lexer.expect(Token::Paren('('))?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Statement::Barrier(crate::Barrier::WORK_GROUP)) + } + "atomicStore" => { + let _ = lexer.next(); + emitter.start(context.expressions); + lexer.open_arguments()?; + let mut expression_ctx = context.as_expression(block, &mut emitter); + let pointer = + self.parse_atomic_pointer(lexer, expression_ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let value = self.parse_general_expression(lexer, expression_ctx)?; + lexer.close_arguments()?; + block.extend(emitter.finish(context.expressions)); + Some(crate::Statement::Store { pointer, value }) + } + "textureStore" => { + let _ = lexer.next(); + emitter.start(context.expressions); + lexer.open_arguments()?; + let mut expr_context = context.as_expression(block, &mut emitter); + let (image, image_span) = self + .parse_general_expression_with_span(lexer, expr_context.reborrow())?; + lexer.expect(Token::Separator(','))?; + let arrayed = match *expr_context.resolve_type(image)? { + crate::TypeInner::Image { arrayed, .. } => arrayed, + _ => return Err(Error::BadTexture(image_span)), + }; + let coordinate = self.parse_general_expression(lexer, expr_context)?; + let array_index = if arrayed { + lexer.expect(Token::Separator(','))?; + Some(self.parse_general_expression( + lexer, + context.as_expression(block, &mut emitter), + )?) + } else { + None + }; + lexer.expect(Token::Separator(','))?; + let value = self.parse_general_expression( + lexer, + context.as_expression(block, &mut emitter), + )?; + lexer.close_arguments()?; + block.extend(emitter.finish(context.expressions)); + Some(crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + }) + } + // assignment or a function call + ident => { + match context.symbol_table.lookup(ident) { + Some(_) => self.parse_assignment_statement( + lexer, + context, + block, + &mut emitter, + )?, + None => self.parse_function_statement( + lexer, + ident, + context.as_expression(block, &mut emitter), + )?, + } + lexer.expect(Token::Separator(';'))?; + None + } + }; + let span = NagaSpan::from(self.pop_rule_span(lexer)); + if let Some(statement) = statement { + block.push(statement, span); + } + } + _ => { + let mut emitter = super::Emitter::default(); + self.parse_assignment_statement(lexer, context, block, &mut emitter)?; + self.pop_rule_span(lexer); + } + } + Ok(()) + } + + fn parse_loop<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: StatementContext<'a, '_, '_>, + emitter: &mut super::Emitter, + ) -> Result> { + let _ = lexer.next(); + let mut body = crate::Block::new(); + let mut continuing = crate::Block::new(); + let mut break_if = None; + + // Push a lexical scope for the loop body + context.symbol_table.push_scope(); + + lexer.expect(Token::Paren('{'))?; + + loop { + if lexer.skip(Token::Word("continuing")) { + // Branch for the `continuing` block, this must be + // the last thing in the loop body + + // Expect a opening brace to start the continuing block + lexer.expect(Token::Paren('{'))?; + loop { + if lexer.skip(Token::Word("break")) { + // Branch for the `break if` statement, this statement + // has the form `break if ;` and must be the last + // statement in a continuing block + + // The break must be followed by an `if` to form + // the break if + lexer.expect(Token::Word("if"))?; + + // Start the emitter to begin parsing an expression + emitter.start(context.expressions); + let condition = self.parse_general_expression( + lexer, + context.as_expression(&mut body, emitter), + )?; + // Add all emits to the continuing body + continuing.extend(emitter.finish(context.expressions)); + // Set the condition of the break if to the newly parsed + // expression + break_if = Some(condition); + + // Expext a semicolon to close the statement + lexer.expect(Token::Separator(';'))?; + // Expect a closing brace to close the continuing block, + // since the break if must be the last statement + lexer.expect(Token::Paren('}'))?; + // Stop parsing the continuing block + break; + } else if lexer.skip(Token::Paren('}')) { + // If we encounter a closing brace it means we have reached + // the end of the continuing block and should stop processing + break; + } else { + // Otherwise try to parse a statement + self.parse_statement(lexer, context.reborrow(), &mut continuing, false)?; + } + } + // Since the continuing block must be the last part of the loop body, + // we expect to see a closing brace to end the loop body + lexer.expect(Token::Paren('}'))?; + break; + } + if lexer.skip(Token::Paren('}')) { + // If we encounter a closing brace it means we have reached + // the end of the loop body and should stop processing + break; + } + // Otherwise try to parse a statement + self.parse_statement(lexer, context.reborrow(), &mut body, false)?; + } + + // Pop the loop body lexical scope + context.symbol_table.pop_scope(); + + Ok(crate::Statement::Loop { + body, + continuing, + break_if, + }) + } + + fn parse_block<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: StatementContext<'a, '_, '_>, + is_uniform_control_flow: bool, + ) -> Result> { + self.push_rule_span(Rule::Block, lexer); + // Push a lexical scope for the block + context.symbol_table.push_scope(); + + lexer.expect(Token::Paren('{'))?; + let mut block = crate::Block::new(); + while !lexer.skip(Token::Paren('}')) { + self.parse_statement( + lexer, + context.reborrow(), + &mut block, + is_uniform_control_flow, + )?; + } + //Pop the block lexical scope + context.symbol_table.pop_scope(); + + self.pop_rule_span(lexer); + Ok(block) + } + + fn parse_varying_binding<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ) -> Result, Error<'a>> { + let mut bind_parser = BindingParser::default(); + self.push_rule_span(Rule::Attribute, lexer); + + while lexer.skip(Token::Attribute) { + let (word, span) = lexer.next_ident_with_span()?; + bind_parser.parse(lexer, word, span)?; + } + + let span = self.pop_rule_span(lexer); + bind_parser.finish(span) + } + + fn parse_function_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + module: &mut crate::Module, + lookup_global_expression: &FastHashMap<&'a str, crate::Expression>, + ) -> Result<(crate::Function, &'a str), Error<'a>> { + self.push_rule_span(Rule::FunctionDecl, lexer); + // read function name + let mut symbol_table = super::SymbolTable::default(); + let (fun_name, span) = lexer.next_ident_with_span()?; + if crate::keywords::wgsl::RESERVED.contains(&fun_name) { + return Err(Error::ReservedKeyword(span)); + } + if let Some(entry) = self + .module_scope_identifiers + .insert(String::from(fun_name), span.clone()) + { + return Err(Error::Redefinition { + previous: entry, + current: span, + }); + } + // populate initial expressions + let mut expressions = Arena::new(); + for (&name, expression) in lookup_global_expression.iter() { + let (span, is_reference) = match *expression { + crate::Expression::GlobalVariable(handle) => ( + module.global_variables.get_span(handle), + module.global_variables[handle].space != crate::AddressSpace::Handle, + ), + crate::Expression::Constant(handle) => (module.constants.get_span(handle), false), + _ => unreachable!(), + }; + let expression = expressions.append(expression.clone(), span); + symbol_table.add( + name, + TypedExpression { + handle: expression, + is_reference, + }, + ); + } + // read parameter list + let mut arguments = Vec::new(); + lexer.expect(Token::Paren('('))?; + let mut ready = true; + while !lexer.skip(Token::Paren(')')) { + if !ready { + return Err(Error::Unexpected( + lexer.next().1, + ExpectedToken::Token(Token::Separator(',')), + )); + } + let mut binding = self.parse_varying_binding(lexer)?; + let (param_name, param_name_span, param_type) = + self.parse_variable_ident_decl(lexer, &mut module.types, &mut module.constants)?; + if crate::keywords::wgsl::RESERVED.contains(¶m_name) { + return Err(Error::ReservedKeyword(param_name_span)); + } + let param_index = arguments.len() as u32; + let expression = expressions.append( + crate::Expression::FunctionArgument(param_index), + NagaSpan::from(param_name_span), + ); + symbol_table.add( + param_name, + TypedExpression { + handle: expression, + is_reference: false, + }, + ); + if let Some(ref mut binding) = binding { + binding.apply_default_interpolation(&module.types[param_type].inner); + } + arguments.push(crate::FunctionArgument { + name: Some(param_name.to_string()), + ty: param_type, + binding, + }); + ready = lexer.skip(Token::Separator(',')); + } + // read return type + let result = if lexer.skip(Token::Arrow) && !lexer.skip(Token::Word("void")) { + let mut binding = self.parse_varying_binding(lexer)?; + let ty = self.parse_type_decl(lexer, None, &mut module.types, &mut module.constants)?; + if let Some(ref mut binding) = binding { + binding.apply_default_interpolation(&module.types[ty].inner); + } + Some(crate::FunctionResult { ty, binding }) + } else { + None + }; + + let mut fun = crate::Function { + name: Some(fun_name.to_string()), + arguments, + result, + local_variables: Arena::new(), + expressions, + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::new(), + }; + + // read body + let mut typifier = super::Typifier::new(); + let mut named_expressions = crate::FastHashMap::default(); + fun.body = self.parse_block( + lexer, + StatementContext { + symbol_table: &mut symbol_table, + typifier: &mut typifier, + variables: &mut fun.local_variables, + expressions: &mut fun.expressions, + named_expressions: &mut named_expressions, + types: &mut module.types, + constants: &mut module.constants, + global_vars: &module.global_variables, + functions: &module.functions, + arguments: &fun.arguments, + }, + true, + )?; + // fixup the IR + ensure_block_returns(&mut fun.body); + // done + self.pop_rule_span(lexer); + + // Set named expressions after block parsing ends + fun.named_expressions = named_expressions; + + Ok((fun, fun_name)) + } + + fn parse_global_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + module: &mut crate::Module, + lookup_global_expression: &mut FastHashMap<&'a str, crate::Expression>, + ) -> Result> { + // read attributes + let mut binding = None; + let mut stage = None; + let mut workgroup_size = [0u32; 3]; + let mut early_depth_test = None; + let (mut bind_index, mut bind_group) = (None, None); + + self.push_rule_span(Rule::Attribute, lexer); + while lexer.skip(Token::Attribute) { + match lexer.next_ident_with_span()? { + ("binding", _) => { + lexer.expect(Token::Paren('('))?; + bind_index = Some(Self::parse_non_negative_i32_literal(lexer)?); + lexer.expect(Token::Paren(')'))?; + } + ("group", _) => { + lexer.expect(Token::Paren('('))?; + bind_group = Some(Self::parse_non_negative_i32_literal(lexer)?); + lexer.expect(Token::Paren(')'))?; + } + ("vertex", _) => { + stage = Some(crate::ShaderStage::Vertex); + } + ("fragment", _) => { + stage = Some(crate::ShaderStage::Fragment); + } + ("compute", _) => { + stage = Some(crate::ShaderStage::Compute); + } + ("workgroup_size", _) => { + lexer.expect(Token::Paren('('))?; + workgroup_size = [1u32; 3]; + for (i, size) in workgroup_size.iter_mut().enumerate() { + *size = Self::parse_generic_non_negative_int_literal(lexer)?; + match lexer.next() { + (Token::Paren(')'), _) => break, + (Token::Separator(','), _) if i != 2 => (), + other => { + return Err(Error::Unexpected( + other.1, + ExpectedToken::WorkgroupSizeSeparator, + )) + } + } + } + } + ("early_depth_test", _) => { + let conservative = if lexer.skip(Token::Paren('(')) { + let (ident, ident_span) = lexer.next_ident_with_span()?; + let value = conv::map_conservative_depth(ident, ident_span)?; + lexer.expect(Token::Paren(')'))?; + Some(value) + } else { + None + }; + early_depth_test = Some(crate::EarlyDepthTest { conservative }); + } + (_, word_span) => return Err(Error::UnknownAttribute(word_span)), + } + } + + let attrib_span = self.pop_rule_span(lexer); + match (bind_group, bind_index) { + (Some(group), Some(index)) => { + binding = Some(crate::ResourceBinding { + group, + binding: index, + }); + } + (Some(_), None) => return Err(Error::MissingAttribute("binding", attrib_span)), + (None, Some(_)) => return Err(Error::MissingAttribute("group", attrib_span)), + (None, None) => {} + } + + // read items + let start = lexer.start_byte_offset(); + match lexer.next() { + (Token::Separator(';'), _) => {} + (Token::Word("struct"), _) => { + let (name, span) = lexer.next_ident_with_span()?; + if crate::keywords::wgsl::RESERVED.contains(&name) { + return Err(Error::ReservedKeyword(span)); + } + let (members, span) = + self.parse_struct_body(lexer, &mut module.types, &mut module.constants)?; + let type_span = NagaSpan::from(lexer.span_from(start)); + let ty = module.types.insert( + crate::Type { + name: Some(name.to_string()), + inner: crate::TypeInner::Struct { members, span }, + }, + type_span, + ); + self.lookup_type.insert(name.to_owned(), ty); + } + (Token::Word("type"), _) => { + let name = lexer.next_ident()?; + lexer.expect(Token::Operation('='))?; + let ty = self.parse_type_decl( + lexer, + Some(name), + &mut module.types, + &mut module.constants, + )?; + self.lookup_type.insert(name.to_owned(), ty); + lexer.expect(Token::Separator(';'))?; + } + (Token::Word("let"), _) => { + let (name, name_span) = lexer.next_ident_with_span()?; + if crate::keywords::wgsl::RESERVED.contains(&name) { + return Err(Error::ReservedKeyword(name_span)); + } + if let Some(entry) = self + .module_scope_identifiers + .insert(String::from(name), name_span.clone()) + { + return Err(Error::Redefinition { + previous: entry, + current: name_span, + }); + } + let given_ty = if lexer.skip(Token::Separator(':')) { + let ty = self.parse_type_decl( + lexer, + None, + &mut module.types, + &mut module.constants, + )?; + Some(ty) + } else { + None + }; + + lexer.expect(Token::Operation('='))?; + let first_token_span = lexer.next(); + let const_handle = self.parse_const_expression_impl( + first_token_span, + lexer, + Some(name), + &mut module.types, + &mut module.constants, + )?; + + if let Some(explicit_ty) = given_ty { + let con = &module.constants[const_handle]; + let type_match = match con.inner { + crate::ConstantInner::Scalar { width, value } => { + module.types[explicit_ty].inner + == crate::TypeInner::Scalar { + kind: value.scalar_kind(), + width, + } + } + crate::ConstantInner::Composite { ty, components: _ } => ty == explicit_ty, + }; + if !type_match { + let expected_inner_str = match con.inner { + crate::ConstantInner::Scalar { width, value } => { + crate::TypeInner::Scalar { + kind: value.scalar_kind(), + width, + } + .to_wgsl(&module.types, &module.constants) + } + crate::ConstantInner::Composite { .. } => module.types[explicit_ty] + .inner + .to_wgsl(&module.types, &module.constants), + }; + return Err(Error::InitializationTypeMismatch( + name_span, + expected_inner_str, + )); + } + } + + lexer.expect(Token::Separator(';'))?; + lookup_global_expression.insert(name, crate::Expression::Constant(const_handle)); + } + (Token::Word("var"), _) => { + let pvar = + self.parse_variable_decl(lexer, &mut module.types, &mut module.constants)?; + if crate::keywords::wgsl::RESERVED.contains(&pvar.name) { + return Err(Error::ReservedKeyword(pvar.name_span)); + } + if let Some(entry) = self + .module_scope_identifiers + .insert(String::from(pvar.name), pvar.name_span.clone()) + { + return Err(Error::Redefinition { + previous: entry, + current: pvar.name_span, + }); + } + let var_handle = module.global_variables.append( + crate::GlobalVariable { + name: Some(pvar.name.to_owned()), + space: pvar.space.unwrap_or(crate::AddressSpace::Handle), + binding: binding.take(), + ty: pvar.ty, + init: pvar.init, + }, + NagaSpan::from(pvar.name_span), + ); + lookup_global_expression + .insert(pvar.name, crate::Expression::GlobalVariable(var_handle)); + } + (Token::Word("fn"), _) => { + let (function, name) = + self.parse_function_decl(lexer, module, lookup_global_expression)?; + match stage { + Some(stage) => module.entry_points.push(crate::EntryPoint { + name: name.to_string(), + stage, + early_depth_test, + workgroup_size, + function, + }), + None => { + module + .functions + .append(function, NagaSpan::from(lexer.span_from(start))); + } + } + } + (Token::End, _) => return Ok(false), + other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)), + } + + match binding { + None => Ok(true), + // we had the attribute but no var? + Some(_) => Err(Error::Other), + } + } + + pub fn parse(&mut self, source: &str) -> Result { + self.reset(); + + let mut module = crate::Module::default(); + let mut lexer = Lexer::new(source); + let mut lookup_global_expression = FastHashMap::default(); + loop { + match self.parse_global_decl(&mut lexer, &mut module, &mut lookup_global_expression) { + Err(error) => return Err(error.as_parse_error(lexer.source)), + Ok(true) => {} + Ok(false) => { + if !self.rules.is_empty() { + log::error!("Reached the end of file, but rule stack is not empty"); + return Err(Error::Other.as_parse_error(lexer.source)); + }; + return Ok(module); + } + } + } + } +} + +pub fn parse_str(source: &str) -> Result { + Parser::new().parse(source) +} diff --git a/third_party/rust/naga/src/front/wgsl/number.rs b/third_party/rust/naga/src/front/wgsl/number.rs new file mode 100644 index 0000000000..fafe1d2270 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/number.rs @@ -0,0 +1,442 @@ +use std::borrow::Cow; + +use super::{NumberError, Token}; + +/// When using this type assume no Abstract Int/Float for now +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Number { + /// Abstract Int (-2^63 ≤ i < 2^63) + AbstractInt(i64), + /// Abstract Float (IEEE-754 binary64) + AbstractFloat(f64), + /// Concrete i32 + I32(i32), + /// Concrete u32 + U32(u32), + /// Concrete f32 + F32(f32), +} + +impl Number { + /// Convert abstract numbers to a plausible concrete counterpart. + /// + /// Return concrete numbers unchanged. If the conversion would be + /// lossy, return an error. + fn abstract_to_concrete(self) -> Result { + match self { + Number::AbstractInt(num) => i32::try_from(num) + .map(Number::I32) + .map_err(|_| NumberError::NotRepresentable), + Number::AbstractFloat(num) => { + let num = num as f32; + if num.is_finite() { + Ok(Number::F32(num)) + } else { + Err(NumberError::NotRepresentable) + } + } + num => Ok(num), + } + } +} + +// TODO: when implementing Creation-Time Expressions, remove the ability to match the minus sign + +pub(super) fn consume_number(input: &str) -> (Token<'_>, &str) { + let (result, rest) = parse(input); + ( + Token::Number(result.and_then(Number::abstract_to_concrete)), + rest, + ) +} + +enum Kind { + Int(IntKind), + Float(FloatKind), +} + +enum IntKind { + I32, + U32, +} + +enum FloatKind { + F32, + F16, +} + +// The following regexes (from the WGSL spec) will be matched: + +// int_literal: +// | / 0 [iu]? / +// | / [1-9][0-9]* [iu]? / +// | / 0[xX][0-9a-fA-F]+ [iu]? / + +// decimal_float_literal: +// | / 0 [fh] / +// | / [1-9][0-9]* [fh] / +// | / [0-9]* \.[0-9]+ ([eE][+-]?[0-9]+)? [fh]? / +// | / [0-9]+ \.[0-9]* ([eE][+-]?[0-9]+)? [fh]? / +// | / [0-9]+ [eE][+-]?[0-9]+ [fh]? / + +// hex_float_literal: +// | / 0[xX][0-9a-fA-F]* \.[0-9a-fA-F]+ ([pP][+-]?[0-9]+ [fh]?)? / +// | / 0[xX][0-9a-fA-F]+ \.[0-9a-fA-F]* ([pP][+-]?[0-9]+ [fh]?)? / +// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? / + +// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing +// -?(?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?)) + +fn parse(input: &str) -> (Result, &str) { + /// returns `true` and consumes `X` bytes from the given byte buffer + /// if the given `X` nr of patterns are found at the start of the buffer + macro_rules! consume { + ($bytes:ident, $($pattern:pat),*) => { + match $bytes { + &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true }, + _ => false, + } + }; + } + + /// consumes one byte from the given byte buffer + /// if one of the given patterns are found at the start of the buffer + /// returning the corresponding expr for the matched pattern + macro_rules! consume_map { + ($bytes:ident, [$($pattern:pat_param => $to:expr),*]) => { + match $bytes { + $( &[$pattern, ref rest @ ..] => { $bytes = rest; Some($to) }, )* + _ => None, + } + }; + } + + /// consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer + /// returning the number of consumed bytes + macro_rules! consume_dec_digits { + ($bytes:ident) => {{ + let start_len = $bytes.len(); + while let &[b'0'..=b'9', ref rest @ ..] = $bytes { + $bytes = rest; + } + start_len - $bytes.len() + }}; + } + + /// consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given byte buffer + /// returning the number of consumed bytes + macro_rules! consume_hex_digits { + ($bytes:ident) => {{ + let start_len = $bytes.len(); + while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes { + $bytes = rest; + } + start_len - $bytes.len() + }}; + } + + /// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str` + macro_rules! rest_to_str { + ($bytes:ident) => { + &input[input.len() - $bytes.len()..] + }; + } + + struct ExtractSubStr<'a>(&'a str); + + impl<'a> ExtractSubStr<'a> { + /// given an `input` and a `start` (tail of the `input`) + /// creates a new [ExtractSubStr] + fn start(input: &'a str, start: &'a [u8]) -> Self { + let start = input.len() - start.len(); + Self(&input[start..]) + } + /// given an `end` (tail of the initial `input`) + /// returns a substring of `input` + fn end(&self, end: &'a [u8]) -> &'a str { + let end = self.0.len() - end.len(); + &self.0[..end] + } + } + + let mut bytes = input.as_bytes(); + + let general_extract = ExtractSubStr::start(input, bytes); + + let is_negative = consume!(bytes, b'-'); + + if consume!(bytes, b'0', b'x' | b'X') { + let digits_extract = ExtractSubStr::start(input, bytes); + + let consumed = consume_hex_digits!(bytes); + + if consume!(bytes, b'.') { + let consumed_after_period = consume_hex_digits!(bytes); + + if consumed + consumed_after_period == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let significand = general_extract.end(bytes); + + if consume!(bytes, b'p' | b'P') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let number = general_extract.end(bytes); + + let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]); + + (parse_hex_float(number, kind), rest_to_str!(bytes)) + } else { + ( + parse_hex_float_missing_exponent(significand, None), + rest_to_str!(bytes), + ) + } + } else { + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let significand = general_extract.end(bytes); + let digits = digits_extract.end(bytes); + + let exp_extract = ExtractSubStr::start(input, bytes); + + if consume!(bytes, b'p' | b'P') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let exponent = exp_extract.end(bytes); + + let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]); + + ( + parse_hex_float_missing_period(significand, exponent, kind), + rest_to_str!(bytes), + ) + } else { + let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]); + + ( + parse_hex_int(is_negative, digits, kind), + rest_to_str!(bytes), + ) + } + } + } else { + let is_first_zero = bytes.first() == Some(&b'0'); + + let consumed = consume_dec_digits!(bytes); + + if consume!(bytes, b'.') { + let consumed_after_period = consume_dec_digits!(bytes); + + if consumed + consumed_after_period == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + if consume!(bytes, b'e' | b'E') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + } + + let number = general_extract.end(bytes); + + let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]); + + (parse_dec_float(number, kind), rest_to_str!(bytes)) + } else { + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + if consume!(bytes, b'e' | b'E') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let number = general_extract.end(bytes); + + let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]); + + (parse_dec_float(number, kind), rest_to_str!(bytes)) + } else { + // make sure the multi-digit numbers don't start with zero + if consumed > 1 && is_first_zero { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let digits_with_sign = general_extract.end(bytes); + + let kind = consume_map!(bytes, [ + b'i' => Kind::Int(IntKind::I32), + b'u' => Kind::Int(IntKind::U32), + b'f' => Kind::Float(FloatKind::F32), + b'h' => Kind::Float(FloatKind::F16) + ]); + + ( + parse_dec(is_negative, digits_with_sign, kind), + rest_to_str!(bytes), + ) + } + } + } +} + +fn parse_hex_float_missing_exponent( + // format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) + significand: &str, + kind: Option, +) -> Result { + let hexf_input = format!("{}{}", significand, "p0"); + parse_hex_float(&hexf_input, kind) +} + +fn parse_hex_float_missing_period( + // format: -?0[xX] [0-9a-fA-F]+ + significand: &str, + // format: [pP][+-]?[0-9]+ + exponent: &str, + kind: Option, +) -> Result { + let hexf_input = format!("{}.{}", significand, exponent); + parse_hex_float(&hexf_input, kind) +} + +fn parse_hex_int( + is_negative: bool, + // format: [0-9a-fA-F]+ + digits: &str, + kind: Option, +) -> Result { + let digits_with_sign = if is_negative { + Cow::Owned(format!("-{}", digits)) + } else { + Cow::Borrowed(digits) + }; + parse_int(&digits_with_sign, kind, 16, is_negative) +} + +fn parse_dec( + is_negative: bool, + // format: -? ( [0-9] | [1-9][0-9]+ ) + digits_with_sign: &str, + kind: Option, +) -> Result { + match kind { + None => parse_int(digits_with_sign, None, 10, is_negative), + Some(Kind::Int(kind)) => parse_int(digits_with_sign, Some(kind), 10, is_negative), + Some(Kind::Float(kind)) => parse_dec_float(digits_with_sign, Some(kind)), + } +} + +// Float parsing notes + +// The following chapters of IEEE 754-2019 are relevant: +// +// 7.4 Overflow (largest finite number is exceeded by what would have been +// the rounded floating-point result were the exponent range unbounded) +// +// 7.5 Underflow (tiny non-zero result is detected; +// for decimal formats tininess is detected before rounding when a non-zero result +// computed as though both the exponent range and the precision were unbounded +// would lie strictly between 2^−126) +// +// 7.6 Inexact (rounded result differs from what would have been computed +// were both exponent range and precision unbounded) + +// The WGSL spec requires us to error: +// on overflow for decimal floating point literals +// on overflow and inexact for hexadecimal floating point literals +// (underflow is not mentioned) + +// hexf_parse errors on overflow, underflow, inexact +// rust std lib float from str handles overflow, underflow, inexact transparently (rounds and will not error) + +// Therefore we only check for overflow manually for decimal floating point literals + +// input format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+ +fn parse_hex_float(input: &str, kind: Option) -> Result { + match kind { + None => match hexf_parse::parse_hexf64(input, false) { + Ok(num) => Ok(Number::AbstractFloat(num)), + // can only be ParseHexfErrorKind::Inexact but we can't check since it's private + _ => Err(NumberError::NotRepresentable), + }, + Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) { + Ok(num) => Ok(Number::F32(num)), + // can only be ParseHexfErrorKind::Inexact but we can't check since it's private + _ => Err(NumberError::NotRepresentable), + }, + Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), + } +} + +// input format: -? ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)? +// | -? [0-9]+ [eE][+-]?[0-9]+ +fn parse_dec_float(input: &str, kind: Option) -> Result { + match kind { + None => { + let num = input.parse::().unwrap(); // will never fail + num.is_finite() + .then(|| Number::AbstractFloat(num)) + .ok_or(NumberError::NotRepresentable) + } + Some(FloatKind::F32) => { + let num = input.parse::().unwrap(); // will never fail + num.is_finite() + .then(|| Number::F32(num)) + .ok_or(NumberError::NotRepresentable) + } + Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), + } +} + +fn parse_int( + input: &str, + kind: Option, + radix: u32, + is_negative: bool, +) -> Result { + fn map_err(e: core::num::ParseIntError) -> NumberError { + match *e.kind() { + core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => { + NumberError::NotRepresentable + } + _ => unreachable!(), + } + } + match kind { + None => match i64::from_str_radix(input, radix) { + Ok(num) => Ok(Number::AbstractInt(num)), + Err(e) => Err(map_err(e)), + }, + Some(IntKind::I32) => match i32::from_str_radix(input, radix) { + Ok(num) => Ok(Number::I32(num)), + Err(e) => Err(map_err(e)), + }, + Some(IntKind::U32) if is_negative => Err(NumberError::NotRepresentable), + Some(IntKind::U32) => match u32::from_str_radix(input, radix) { + Ok(num) => Ok(Number::U32(num)), + Err(e) => Err(map_err(e)), + }, + } +} diff --git a/third_party/rust/naga/src/front/wgsl/tests.rs b/third_party/rust/naga/src/front/wgsl/tests.rs new file mode 100644 index 0000000000..33fc541acb --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/tests.rs @@ -0,0 +1,458 @@ +use super::parse_str; + +#[test] +fn parse_comment() { + parse_str( + "// + //// + ///////////////////////////////////////////////////////// asda + //////////////////// dad ////////// / + ///////////////////////////////////////////////////////////////////////////////////////////////////// + // + ", + ) + .unwrap(); +} + +#[test] +fn parse_types() { + parse_str("let a : i32 = 2;").unwrap(); + assert!(parse_str("let a : x32 = 2;").is_err()); + parse_str("var t: texture_2d;").unwrap(); + parse_str("var t: texture_cube_array;").unwrap(); + parse_str("var t: texture_multisampled_2d;").unwrap(); + parse_str("var t: texture_storage_1d;").unwrap(); + parse_str("var t: texture_storage_3d;").unwrap(); +} + +#[test] +fn parse_type_inference() { + parse_str( + " + fn foo() { + let a = 2u; + let b: u32 = a; + var x = 3.; + var y = vec2(1, 2); + }", + ) + .unwrap(); + assert!(parse_str( + " + fn foo() { let c : i32 = 2.0; }", + ) + .is_err()); +} + +#[test] +fn parse_type_cast() { + parse_str( + " + let a : i32 = 2; + fn main() { + var x: f32 = f32(a); + x = f32(i32(a + 1) / 2); + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + let x: vec2 = vec2(1.0, 2.0); + let y: vec2 = vec2(x); + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + let x: vec2 = vec2(0.0); + } + ", + ) + .unwrap(); + assert!(parse_str( + " + fn main() { + let x: vec2 = vec2(0); + } + ", + ) + .is_err()); +} + +#[test] +fn parse_struct() { + parse_str( + " + struct Foo { x: i32 } + struct Bar { + @size(16) x: vec2, + @align(16) y: f32, + @size(32) @align(128) z: vec3, + }; + struct Empty {} + var s: Foo; + ", + ) + .unwrap(); +} + +#[test] +fn parse_standard_fun() { + parse_str( + " + fn main() { + var x: i32 = min(max(1, 2), 3); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_statement() { + parse_str( + " + fn main() { + ; + {} + {;} + } + ", + ) + .unwrap(); + + parse_str( + " + fn foo() {} + fn bar() { foo(); } + ", + ) + .unwrap(); +} + +#[test] +fn parse_if() { + parse_str( + " + fn main() { + if true { + discard; + } else {} + if 0 != 1 {} + if false { + return; + } else if true { + return; + } else {} + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_parentheses_if() { + parse_str( + " + fn main() { + if (true) { + discard; + } else {} + if (0 != 1) {} + if (false) { + return; + } else if (true) { + return; + } else {} + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_loop() { + parse_str( + " + fn main() { + var i: i32 = 0; + loop { + if i == 1 { break; } + continuing { i = 1; } + } + loop { + if i == 0 { continue; } + break; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + var found: bool = false; + var i: i32 = 0; + while !found { + if i == 10 { + found = true; + } + + i = i + 1; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + while true { + break; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + var a: i32 = 0; + for(var i: i32 = 0; i < 4; i = i + 1) { + a = a + 2; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + for(;;) { + break; + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_switch() { + parse_str( + " + fn main() { + var pos: f32; + switch (3) { + case 0, 1: { pos = 0.0; } + case 2: { pos = 1.0; fallthrough; } + case 3: {} + default: { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_switch_optional_colon_in_case() { + parse_str( + " + fn main() { + var pos: f32; + switch (3) { + case 0, 1 { pos = 0.0; } + case 2 { pos = 1.0; fallthrough; } + case 3 {} + default { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_parentheses_switch() { + parse_str( + " + fn main() { + var pos: f32; + switch pos > 1.0 { + default: { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_load() { + parse_str( + " + var t: texture_3d; + fn foo() { + let r: vec4 = textureLoad(t, vec3(0.0, 1.0, 2.0), 1); + } + ", + ) + .unwrap(); + parse_str( + " + var t: texture_multisampled_2d_array; + fn foo() { + let r: vec4 = textureLoad(t, vec2(10, 20), 2, 3); + } + ", + ) + .unwrap(); + parse_str( + " + var t: texture_storage_1d_array; + fn foo() { + let r: vec4 = textureLoad(t, 10, 2); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_store() { + parse_str( + " + var t: texture_storage_2d; + fn foo() { + textureStore(t, vec2(10, 20), vec4(0.0, 1.0, 2.0, 3.0)); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_query() { + parse_str( + " + var t: texture_multisampled_2d_array; + fn foo() { + var dim: vec2 = textureDimensions(t); + dim = textureDimensions(t, 0); + let layers: i32 = textureNumLayers(t); + let samples: i32 = textureNumSamples(t); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_postfix() { + parse_str( + "fn foo() { + let x: f32 = vec4(1.0, 2.0, 3.0, 4.0).xyz.rgbr.aaaa.wz.g; + let y: f32 = fract(vec2(0.5, x)).x; + }", + ) + .unwrap(); +} + +#[test] +fn parse_expressions() { + parse_str("fn foo() { + let x: f32 = select(0.0, 1.0, true); + let y: vec2 = select(vec2(1.0, 1.0), vec2(x, x), vec2(x < 0.5, x > 0.5)); + let z: bool = !(0.0 == 1.0); + }").unwrap(); +} + +#[test] +fn parse_pointers() { + parse_str( + "fn foo() { + var x: f32 = 1.0; + let px = &x; + let py = frexp(0.5, px); + }", + ) + .unwrap(); +} + +#[test] +fn parse_struct_instantiation() { + parse_str( + " + struct Foo { + a: f32, + b: vec3, + } + + @fragment + fn fs_main() { + var foo: Foo = Foo(0.0, vec3(0.0, 1.0, 42.0)); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_array_length() { + parse_str( + " + struct Foo { + data: array + } // this is used as both input and output for convenience + + @group(0) @binding(0) + var foo: Foo; + + @group(0) @binding(1) + var bar: array; + + fn baz() { + var x: u32 = arrayLength(foo.data); + var y: u32 = arrayLength(bar); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_storage_buffers() { + parse_str( + " + @group(0) @binding(0) + var foo: array; + ", + ) + .unwrap(); + parse_str( + " + @group(0) @binding(0) + var foo: array; + ", + ) + .unwrap(); + parse_str( + " + @group(0) @binding(0) + var foo: array; + ", + ) + .unwrap(); + parse_str( + " + @group(0) @binding(0) + var foo: array; + ", + ) + .unwrap(); +} -- cgit v1.2.3