diff options
Diffstat (limited to 'third_party/rust/naga/src/front/wgsl/construction.rs')
-rw-r--r-- | third_party/rust/naga/src/front/wgsl/construction.rs | 679 |
1 files changed, 679 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/front/wgsl/construction.rs b/third_party/rust/naga/src/front/wgsl/construction.rs new file mode 100644 index 0000000000..43e719d0f3 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/construction.rs @@ -0,0 +1,679 @@ +use crate::{ + proc::TypeResolution, Arena, ArraySize, Bytes, Constant, ConstantInner, Expression, Handle, + ScalarKind, ScalarValue, Span as NagaSpan, Type, TypeInner, UniqueArena, VectorSize, +}; + +use super::{Error, ExpressionContext, Lexer, Parser, Rule, Span, Token}; + +/// Represents the type of the constructor +/// +/// Vectors, Matrices and Arrays can have partial type information +/// which later gets inferred from the constructor parameters +enum ConstructorType { + Scalar { + kind: ScalarKind, + width: Bytes, + }, + PartialVector { + size: VectorSize, + }, + Vector { + size: VectorSize, + kind: ScalarKind, + width: Bytes, + }, + PartialMatrix { + columns: VectorSize, + rows: VectorSize, + }, + Matrix { + columns: VectorSize, + rows: VectorSize, + width: Bytes, + }, + PartialArray, + Array { + base: Handle<Type>, + size: ArraySize, + stride: u32, + }, + Struct(Handle<Type>), +} + +impl ConstructorType { + const fn to_type_resolution(&self) -> Option<TypeResolution> { + Some(match *self { + ConstructorType::Scalar { kind, width } => { + TypeResolution::Value(TypeInner::Scalar { kind, width }) + } + ConstructorType::Vector { size, kind, width } => { + TypeResolution::Value(TypeInner::Vector { size, kind, width }) + } + ConstructorType::Matrix { + columns, + rows, + width, + } => TypeResolution::Value(TypeInner::Matrix { + columns, + rows, + width, + }), + ConstructorType::Array { base, size, stride } => { + TypeResolution::Value(TypeInner::Array { base, size, stride }) + } + ConstructorType::Struct(handle) => TypeResolution::Handle(handle), + _ => return None, + }) + } +} + +impl ConstructorType { + fn to_error_string(&self, types: &UniqueArena<Type>, constants: &Arena<Constant>) -> String { + match *self { + ConstructorType::Scalar { kind, width } => kind.to_wgsl(width), + ConstructorType::PartialVector { size } => { + format!("vec{}<?>", size as u32,) + } + ConstructorType::Vector { size, kind, width } => { + format!("vec{}<{}>", size as u32, kind.to_wgsl(width)) + } + ConstructorType::PartialMatrix { columns, rows } => { + format!("mat{}x{}<?>", columns as u32, rows as u32,) + } + ConstructorType::Matrix { + columns, + rows, + width, + } => { + format!( + "mat{}x{}<{}>", + columns as u32, + rows as u32, + ScalarKind::Float.to_wgsl(width) + ) + } + ConstructorType::PartialArray => "array<?, ?>".to_string(), + ConstructorType::Array { base, size, .. } => { + format!( + "array<{}, {}>", + types[base].name.as_deref().unwrap_or("?"), + match size { + ArraySize::Constant(size) => { + constants[size] + .to_array_length() + .map(|len| len.to_string()) + .unwrap_or_else(|| "?".to_string()) + } + _ => unreachable!(), + } + ) + } + ConstructorType::Struct(handle) => types[handle] + .name + .clone() + .unwrap_or_else(|| "?".to_string()), + } + } +} + +fn parse_constructor_type<'a>( + parser: &mut Parser, + lexer: &mut Lexer<'a>, + word: &'a str, + type_arena: &mut UniqueArena<Type>, + const_arena: &mut Arena<Constant>, +) -> Result<Option<ConstructorType>, Error<'a>> { + if let Some((kind, width)) = super::conv::get_scalar_type(word) { + return Ok(Some(ConstructorType::Scalar { kind, width })); + } + + let partial = match word { + "vec2" => ConstructorType::PartialVector { + size: VectorSize::Bi, + }, + "vec3" => ConstructorType::PartialVector { + size: VectorSize::Tri, + }, + "vec4" => ConstructorType::PartialVector { + size: VectorSize::Quad, + }, + "mat2x2" => ConstructorType::PartialMatrix { + columns: VectorSize::Bi, + rows: VectorSize::Bi, + }, + "mat2x3" => ConstructorType::PartialMatrix { + columns: VectorSize::Bi, + rows: VectorSize::Tri, + }, + "mat2x4" => ConstructorType::PartialMatrix { + columns: VectorSize::Bi, + rows: VectorSize::Quad, + }, + "mat3x2" => ConstructorType::PartialMatrix { + columns: VectorSize::Tri, + rows: VectorSize::Bi, + }, + "mat3x3" => ConstructorType::PartialMatrix { + columns: VectorSize::Tri, + rows: VectorSize::Tri, + }, + "mat3x4" => ConstructorType::PartialMatrix { + columns: VectorSize::Tri, + rows: VectorSize::Quad, + }, + "mat4x2" => ConstructorType::PartialMatrix { + columns: VectorSize::Quad, + rows: VectorSize::Bi, + }, + "mat4x3" => ConstructorType::PartialMatrix { + columns: VectorSize::Quad, + rows: VectorSize::Tri, + }, + "mat4x4" => ConstructorType::PartialMatrix { + columns: VectorSize::Quad, + rows: VectorSize::Quad, + }, + "array" => ConstructorType::PartialArray, + _ => return Ok(None), + }; + + // parse component type if present + match (lexer.peek().0, partial) { + (Token::Paren('<'), ConstructorType::PartialVector { size }) => { + let (kind, width) = lexer.next_scalar_generic()?; + Ok(Some(ConstructorType::Vector { size, kind, width })) + } + (Token::Paren('<'), ConstructorType::PartialMatrix { columns, rows }) => { + let (kind, width, span) = lexer.next_scalar_generic_with_span()?; + match kind { + ScalarKind::Float => Ok(Some(ConstructorType::Matrix { + columns, + rows, + width, + })), + _ => Err(Error::BadMatrixScalarKind(span, kind, width)), + } + } + (Token::Paren('<'), ConstructorType::PartialArray) => { + lexer.expect_generic_paren('<')?; + let base = parser.parse_type_decl(lexer, None, type_arena, const_arena)?; + let size = if lexer.skip(Token::Separator(',')) { + let const_handle = parser.parse_const_expression(lexer, type_arena, const_arena)?; + ArraySize::Constant(const_handle) + } else { + ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + let stride = { + parser.layouter.update(type_arena, const_arena).unwrap(); + parser.layouter[base].to_stride() + }; + + Ok(Some(ConstructorType::Array { base, size, stride })) + } + (_, partial) => Ok(Some(partial)), + } +} + +/// Expects [`Rule::PrimaryExpr`] on top of rule stack; if returning Some(_), pops it. +pub(super) fn parse_construction<'a>( + parser: &mut Parser, + lexer: &mut Lexer<'a>, + type_name: &'a str, + type_span: Span, + mut ctx: ExpressionContext<'a, '_, '_>, +) -> Result<Option<Handle<Expression>>, Error<'a>> { + assert_eq!( + parser.rules.last().map(|&(ref rule, _)| rule.clone()), + Some(Rule::PrimaryExpr) + ); + let dst_ty = match parser.lookup_type.get(type_name) { + Some(&handle) => ConstructorType::Struct(handle), + None => match parse_constructor_type(parser, lexer, type_name, ctx.types, ctx.constants)? { + Some(inner) => inner, + None => { + match parser.parse_type_decl_impl( + lexer, + super::TypeAttributes::default(), + type_name, + ctx.types, + ctx.constants, + )? { + Some(_) => { + return Err(Error::TypeNotConstructible(type_span)); + } + None => return Ok(None), + } + } + }, + }; + + lexer.open_arguments()?; + + let mut components = Vec::new(); + let mut spans = Vec::new(); + + if lexer.peek().0 == Token::Paren(')') { + let _ = lexer.next(); + } else { + while components.is_empty() || lexer.next_argument()? { + let (component, span) = lexer + .capture_span(|lexer| parser.parse_general_expression(lexer, ctx.reborrow()))?; + components.push(component); + spans.push(span); + } + } + + enum Components<'a> { + None, + One { + component: Handle<Expression>, + span: Span, + ty: &'a TypeInner, + }, + Many { + components: Vec<Handle<Expression>>, + spans: Vec<Span>, + first_component_ty: &'a TypeInner, + }, + } + + impl<'a> Components<'a> { + fn into_components_vec(self) -> Vec<Handle<Expression>> { + match self { + Components::None => vec![], + Components::One { component, .. } => vec![component], + Components::Many { components, .. } => components, + } + } + } + + let components = match *components.as_slice() { + [] => Components::None, + [component] => { + ctx.resolve_type(component)?; + Components::One { + component, + span: spans[0].clone(), + ty: ctx.typifier.get(component, ctx.types), + } + } + [component, ..] => { + ctx.resolve_type(component)?; + Components::Many { + components, + spans, + first_component_ty: ctx.typifier.get(component, ctx.types), + } + } + }; + + let expr = match (components, dst_ty) { + // Empty constructor + (Components::None, dst_ty) => { + let ty = match dst_ty.to_type_resolution() { + Some(TypeResolution::Handle(handle)) => handle, + Some(TypeResolution::Value(inner)) => ctx + .types + .insert(Type { name: None, inner }, Default::default()), + None => return Err(Error::TypeNotInferrable(type_span)), + }; + + return match ctx.create_zero_value_constant(ty) { + Some(constant) => { + let span = parser.pop_rule_span(lexer); + Ok(Some(ctx.interrupt_emitter( + Expression::Constant(constant), + span.into(), + ))) + } + None => Err(Error::TypeNotConstructible(type_span)), + }; + } + + // Scalar constructor & conversion (scalar -> scalar) + ( + Components::One { + component, + ty: &TypeInner::Scalar { .. }, + .. + }, + ConstructorType::Scalar { kind, width }, + ) => Expression::As { + expr: component, + kind, + convert: Some(width), + }, + + // Vector conversion (vector -> vector) + ( + Components::One { + component, + ty: &TypeInner::Vector { size: src_size, .. }, + .. + }, + ConstructorType::Vector { + size: dst_size, + kind: dst_kind, + width: dst_width, + }, + ) if dst_size == src_size => Expression::As { + expr: component, + kind: dst_kind, + convert: Some(dst_width), + }, + + // Vector conversion (vector -> vector) - partial + ( + Components::One { + component, + ty: + &TypeInner::Vector { + size: src_size, + kind: src_kind, + .. + }, + .. + }, + ConstructorType::PartialVector { size: dst_size }, + ) if dst_size == src_size => Expression::As { + expr: component, + kind: src_kind, + convert: None, + }, + + // Matrix conversion (matrix -> matrix) + ( + Components::One { + component, + ty: + &TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + ConstructorType::Matrix { + columns: dst_columns, + rows: dst_rows, + width: dst_width, + }, + ) if dst_columns == src_columns && dst_rows == src_rows => Expression::As { + expr: component, + kind: ScalarKind::Float, + convert: Some(dst_width), + }, + + // Matrix conversion (matrix -> matrix) - partial + ( + Components::One { + component, + ty: + &TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + ConstructorType::PartialMatrix { + columns: dst_columns, + rows: dst_rows, + }, + ) if dst_columns == src_columns && dst_rows == src_rows => Expression::As { + expr: component, + kind: ScalarKind::Float, + convert: None, + }, + + // Vector constructor (splat) - infer type + ( + Components::One { + component, + ty: &TypeInner::Scalar { .. }, + .. + }, + ConstructorType::PartialVector { size }, + ) => Expression::Splat { + size, + value: component, + }, + + // Vector constructor (splat) + ( + Components::One { + component, + ty: + &TypeInner::Scalar { + kind: src_kind, + width: src_width, + .. + }, + .. + }, + ConstructorType::Vector { + size, + kind: dst_kind, + width: dst_width, + }, + ) if dst_kind == src_kind || dst_width == src_width => Expression::Splat { + size, + value: component, + }, + + // Vector constructor (by elements) + ( + Components::Many { + components, + first_component_ty: + &TypeInner::Scalar { kind, width } | &TypeInner::Vector { kind, width, .. }, + .. + }, + ConstructorType::PartialVector { size }, + ) + | ( + Components::Many { + components, + first_component_ty: &TypeInner::Scalar { .. } | &TypeInner::Vector { .. }, + .. + }, + ConstructorType::Vector { size, width, kind }, + ) => { + let ty = ctx.types.insert( + Type { + name: None, + inner: TypeInner::Vector { size, kind, width }, + }, + Default::default(), + ); + Expression::Compose { ty, components } + } + + // Matrix constructor (by elements) + ( + Components::Many { + components, + first_component_ty: &TypeInner::Scalar { width, .. }, + .. + }, + ConstructorType::PartialMatrix { columns, rows }, + ) + | ( + Components::Many { + components, + first_component_ty: &TypeInner::Scalar { .. }, + .. + }, + ConstructorType::Matrix { + columns, + rows, + width, + }, + ) => { + let vec_ty = ctx.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + width, + kind: ScalarKind::Float, + size: rows, + }, + }, + Default::default(), + ); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.expressions.append( + Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect(); + + let ty = ctx.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + width, + }, + }, + Default::default(), + ); + Expression::Compose { ty, components } + } + + // Matrix constructor (by columns) + ( + Components::Many { + components, + first_component_ty: &TypeInner::Vector { width, .. }, + .. + }, + ConstructorType::PartialMatrix { columns, rows }, + ) + | ( + Components::Many { + components, + first_component_ty: &TypeInner::Vector { .. }, + .. + }, + ConstructorType::Matrix { + columns, + rows, + width, + }, + ) => { + let ty = ctx.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + width, + }, + }, + Default::default(), + ); + Expression::Compose { ty, components } + } + + // Array constructor - infer type + (components, ConstructorType::PartialArray) => { + let components = components.into_components_vec(); + + let base = match ctx.typifier[components[0]].clone() { + TypeResolution::Handle(ty) => ty, + TypeResolution::Value(inner) => ctx + .types + .insert(Type { name: None, inner }, Default::default()), + }; + + let size = Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Uint(components.len() as u64), + }, + }; + + let inner = TypeInner::Array { + base, + size: ArraySize::Constant(ctx.constants.append(size, Default::default())), + stride: { + parser.layouter.update(ctx.types, ctx.constants).unwrap(); + parser.layouter[base].to_stride() + }, + }; + + let ty = ctx + .types + .insert(Type { name: None, inner }, Default::default()); + + Expression::Compose { ty, components } + } + + // Array constructor + (components, ConstructorType::Array { base, size, stride }) => { + let components = components.into_components_vec(); + let inner = TypeInner::Array { base, size, stride }; + let ty = ctx + .types + .insert(Type { name: None, inner }, Default::default()); + Expression::Compose { ty, components } + } + + // Struct constructor + (components, ConstructorType::Struct(ty)) => Expression::Compose { + ty, + components: components.into_components_vec(), + }, + + // ERRORS + + // Bad conversion (type cast) + ( + Components::One { + span, ty: src_ty, .. + }, + dst_ty, + ) => { + return Err(Error::BadTypeCast { + span, + from_type: src_ty.to_wgsl(ctx.types, ctx.constants), + to_type: dst_ty.to_error_string(ctx.types, ctx.constants), + }); + } + + // Too many parameters for scalar constructor + (Components::Many { spans, .. }, ConstructorType::Scalar { .. }) => { + return Err(Error::UnexpectedComponents(Span { + start: spans[1].start, + end: spans.last().unwrap().end, + })); + } + + // Parameters are of the wrong type for vector or matrix constructor + ( + Components::Many { spans, .. }, + ConstructorType::Vector { .. } + | ConstructorType::Matrix { .. } + | ConstructorType::PartialVector { .. } + | ConstructorType::PartialMatrix { .. }, + ) => { + return Err(Error::InvalidConstructorComponentType(spans[0].clone(), 0)); + } + }; + + let span = NagaSpan::from(parser.pop_rule_span(lexer)); + Ok(Some(ctx.expressions.append(expr, span))) +} |