diff options
Diffstat (limited to 'third_party/rust/naga/src/front/glsl/context.rs')
-rw-r--r-- | third_party/rust/naga/src/front/glsl/context.rs | 1588 |
1 files changed, 1588 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs new file mode 100644 index 0000000000..48574fc5b0 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/context.rs @@ -0,0 +1,1588 @@ +use super::{ + ast::{ + GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier, + VariableReference, + }, + error::{Error, ErrorKind}, + types::{scalar_components, type_power}, + Parser, Result, +}; +use crate::{ + front::{Emitter, Typifier}, + AddressSpace, Arena, BinaryOperator, Block, Constant, Expression, FastHashMap, + FunctionArgument, Handle, LocalVariable, RelationalFunction, ScalarKind, ScalarValue, Span, + Statement, Type, TypeInner, VectorSize, +}; +use std::{convert::TryFrom, ops::Index}; + +/// The position at which an expression is, used while lowering +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum ExprPos { + /// The expression is in the left hand side of an assignment + Lhs, + /// The expression is in the right hand side of an assignment + Rhs, + /// The expression is an array being indexed, needed to allow constant + /// arrays to be dynamically indexed + AccessBase { + /// The index is a constant + constant_index: bool, + }, +} + +impl ExprPos { + /// Returns an lhs position if the current position is lhs otherwise AccessBase + const fn maybe_access_base(&self, constant_index: bool) -> Self { + match *self { + ExprPos::Lhs => *self, + _ => ExprPos::AccessBase { constant_index }, + } + } +} + +#[derive(Debug)] +pub struct Context { + pub expressions: Arena<Expression>, + pub locals: Arena<LocalVariable>, + pub arguments: Vec<FunctionArgument>, + + pub parameters: Vec<Handle<Type>>, + pub parameters_info: Vec<ParameterInfo>, + + pub symbol_table: crate::front::SymbolTable<String, VariableReference>, + pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>, + + pub typifier: Typifier, + emitter: Emitter, + stmt_ctx: Option<StmtContext>, +} + +impl Context { + pub fn new(parser: &Parser, body: &mut Block) -> Self { + let mut this = Context { + expressions: Arena::new(), + locals: Arena::new(), + arguments: Vec::new(), + + parameters: Vec::new(), + parameters_info: Vec::new(), + + symbol_table: crate::front::SymbolTable::default(), + samplers: FastHashMap::default(), + + typifier: Typifier::new(), + emitter: Emitter::default(), + stmt_ctx: Some(StmtContext::new()), + }; + + this.emit_start(); + + for &(ref name, lookup) in parser.global_variables.iter() { + this.add_global(parser, name, lookup, body) + } + + this + } + + pub fn add_global( + &mut self, + parser: &Parser, + name: &str, + GlobalLookup { + kind, + entry_arg, + mutable, + }: GlobalLookup, + body: &mut Block, + ) { + self.emit_end(body); + let (expr, load, constant) = match kind { + GlobalLookupKind::Variable(v) => { + let span = parser.module.global_variables.get_span(v); + let res = ( + self.expressions.append(Expression::GlobalVariable(v), span), + parser.module.global_variables[v].space != AddressSpace::Handle, + None, + ); + self.emit_start(); + + res + } + GlobalLookupKind::BlockSelect(handle, index) => { + let span = parser.module.global_variables.get_span(handle); + let base = self + .expressions + .append(Expression::GlobalVariable(handle), span); + self.emit_start(); + let expr = self + .expressions + .append(Expression::AccessIndex { base, index }, span); + + ( + expr, + { + let ty = parser.module.global_variables[handle].ty; + + match parser.module.types[ty].inner { + TypeInner::Struct { ref members, .. } => { + if let TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = parser.module.types[members[index as usize].ty].inner + { + false + } else { + true + } + } + _ => true, + } + }, + None, + ) + } + GlobalLookupKind::Constant(v, ty) => { + let span = parser.module.constants.get_span(v); + let res = ( + self.expressions.append(Expression::Constant(v), span), + false, + Some((v, ty)), + ); + self.emit_start(); + res + } + }; + + let var = VariableReference { + expr, + load, + mutable, + constant, + entry_arg, + }; + + self.symbol_table.add(name.into(), var); + } + + /// Starts the expression emitter + /// + /// # Panics + /// + /// - If called twice in a row without calling [`emit_end`][Self::emit_end]. + #[inline] + pub fn emit_start(&mut self) { + self.emitter.start(&self.expressions) + } + + /// Emits all the expressions captured by the emitter to the passed `body` + /// + /// # Panics + /// + /// - If called before calling [`emit_start`]. + /// - If called twice in a row without calling [`emit_start`]. + /// + /// [`emit_start`]: Self::emit_start + pub fn emit_end(&mut self, body: &mut Block) { + body.extend(self.emitter.finish(&self.expressions)) + } + + /// Emits all the expressions captured by the emitter to the passed `body` + /// and starts the emitter again + /// + /// # Panics + /// + /// - If called before calling [`emit_start`][Self::emit_start]. + pub fn emit_restart(&mut self, body: &mut Block) { + self.emit_end(body); + self.emit_start() + } + + pub fn add_expression( + &mut self, + expr: Expression, + meta: Span, + body: &mut Block, + ) -> Handle<Expression> { + let needs_pre_emit = expr.needs_pre_emit(); + if needs_pre_emit { + self.emit_end(body); + } + let handle = self.expressions.append(expr, meta); + if needs_pre_emit { + self.emit_start(); + } + handle + } + + /// Add variable to current scope + /// + /// Returns a variable if a variable with the same name was already defined, + /// otherwise returns `None` + pub fn add_local_var( + &mut self, + name: String, + expr: Handle<Expression>, + mutable: bool, + ) -> Option<VariableReference> { + let var = VariableReference { + expr, + load: true, + mutable, + constant: None, + entry_arg: None, + }; + + self.symbol_table.add(name, var) + } + + /// Add function argument to current scope + pub fn add_function_arg( + &mut self, + parser: &mut Parser, + body: &mut Block, + name_meta: Option<(String, Span)>, + ty: Handle<Type>, + qualifier: ParameterQualifier, + ) { + let index = self.arguments.len(); + let mut arg = FunctionArgument { + name: name_meta.as_ref().map(|&(ref name, _)| name.clone()), + ty, + binding: None, + }; + self.parameters.push(ty); + + let opaque = match parser.module.types[ty].inner { + TypeInner::Image { .. } | TypeInner::Sampler { .. } => true, + _ => false, + }; + + if qualifier.is_lhs() { + let span = parser.module.types.get_span(arg.ty); + arg.ty = parser.module.types.insert( + Type { + name: None, + inner: TypeInner::Pointer { + base: arg.ty, + space: AddressSpace::Function, + }, + }, + span, + ) + } + + self.arguments.push(arg); + + self.parameters_info.push(ParameterInfo { + qualifier, + depth: false, + }); + + if let Some((name, meta)) = name_meta { + let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta, body); + let mutable = qualifier != ParameterQualifier::Const && !opaque; + let load = qualifier.is_lhs(); + + let var = if mutable && !load { + let handle = self.locals.append( + LocalVariable { + name: Some(name.clone()), + ty, + init: None, + }, + meta, + ); + let local_expr = self.add_expression(Expression::LocalVariable(handle), meta, body); + + self.emit_restart(body); + + body.push( + Statement::Store { + pointer: local_expr, + value: expr, + }, + meta, + ); + + VariableReference { + expr: local_expr, + load: true, + mutable, + constant: None, + entry_arg: None, + } + } else { + VariableReference { + expr, + load, + mutable, + constant: None, + entry_arg: None, + } + }; + + self.symbol_table.add(name, var); + } + } + + /// Returns a [`StmtContext`](StmtContext) to be used in parsing and lowering + /// + /// # Panics + /// - If more than one [`StmtContext`](StmtContext) are active at the same + /// time or if the previous call didn't use it in lowering. + #[must_use] + pub fn stmt_ctx(&mut self) -> StmtContext { + self.stmt_ctx.take().unwrap() + } + + /// Lowers a [`HirExpr`](HirExpr) which might produce a [`Expression`](Expression). + /// + /// consumes a [`StmtContext`](StmtContext) returning it to the context so + /// that it can be used again later. + pub fn lower( + &mut self, + mut stmt: StmtContext, + parser: &mut Parser, + expr: Handle<HirExpr>, + pos: ExprPos, + body: &mut Block, + ) -> Result<(Option<Handle<Expression>>, Span)> { + let res = self.lower_inner(&stmt, parser, expr, pos, body); + + stmt.hir_exprs.clear(); + self.stmt_ctx = Some(stmt); + + res + } + + /// Similar to [`lower`](Self::lower) but returns an error if the expression + /// returns void (ie. doesn't produce a [`Expression`](Expression)). + /// + /// consumes a [`StmtContext`](StmtContext) returning it to the context so + /// that it can be used again later. + pub fn lower_expect( + &mut self, + mut stmt: StmtContext, + parser: &mut Parser, + expr: Handle<HirExpr>, + pos: ExprPos, + body: &mut Block, + ) -> Result<(Handle<Expression>, Span)> { + let res = self.lower_expect_inner(&stmt, parser, expr, pos, body); + + stmt.hir_exprs.clear(); + self.stmt_ctx = Some(stmt); + + res + } + + /// internal implementation of [`lower_expect`](Self::lower_expect) + /// + /// this method is only public because it's used in + /// [`function_call`](Parser::function_call), unless you know what + /// you're doing use [`lower_expect`](Self::lower_expect) + pub fn lower_expect_inner( + &mut self, + stmt: &StmtContext, + parser: &mut Parser, + expr: Handle<HirExpr>, + pos: ExprPos, + body: &mut Block, + ) -> Result<(Handle<Expression>, Span)> { + let (maybe_expr, meta) = self.lower_inner(stmt, parser, expr, pos, body)?; + + let expr = match maybe_expr { + Some(e) => e, + None => { + return Err(Error { + kind: ErrorKind::SemanticError("Expression returns void".into()), + meta, + }) + } + }; + + Ok((expr, meta)) + } + + fn lower_store( + &mut self, + pointer: Handle<Expression>, + value: Handle<Expression>, + meta: Span, + body: &mut Block, + ) { + if let Expression::Swizzle { + size, + mut vector, + pattern, + } = self.expressions[pointer] + { + // Stores to swizzled values are not directly supported, + // lower them as series of per-component stores. + let size = match size { + VectorSize::Bi => 2, + VectorSize::Tri => 3, + VectorSize::Quad => 4, + }; + + if let Expression::Load { pointer } = self.expressions[vector] { + vector = pointer; + } + + #[allow(clippy::needless_range_loop)] + for index in 0..size { + let dst = self.add_expression( + Expression::AccessIndex { + base: vector, + index: pattern[index].index(), + }, + meta, + body, + ); + let src = self.add_expression( + Expression::AccessIndex { + base: value, + index: index as u32, + }, + meta, + body, + ); + + self.emit_restart(body); + + body.push( + Statement::Store { + pointer: dst, + value: src, + }, + meta, + ); + } + } else { + self.emit_restart(body); + + body.push(Statement::Store { pointer, value }, meta); + } + } + + /// Internal implementation of [`lower`](Self::lower) + fn lower_inner( + &mut self, + stmt: &StmtContext, + parser: &mut Parser, + expr: Handle<HirExpr>, + pos: ExprPos, + body: &mut Block, + ) -> Result<(Option<Handle<Expression>>, Span)> { + let HirExpr { ref kind, meta } = stmt.hir_exprs[expr]; + + log::debug!("Lowering {:?}", expr); + + let handle = match *kind { + HirExprKind::Access { base, index } => { + let (index, index_meta) = + self.lower_expect_inner(stmt, parser, index, ExprPos::Rhs, body)?; + let maybe_constant_index = match pos { + // Don't try to generate `AccessIndex` if in a LHS position, since it + // wouldn't produce a pointer. + ExprPos::Lhs => None, + _ => parser.solve_constant(self, index, index_meta).ok(), + }; + + let base = self + .lower_expect_inner( + stmt, + parser, + base, + pos.maybe_access_base(maybe_constant_index.is_some()), + body, + )? + .0; + + let pointer = maybe_constant_index + .and_then(|constant| { + Some(self.add_expression( + Expression::AccessIndex { + base, + index: match parser.module.constants[constant].inner { + crate::ConstantInner::Scalar { + value: ScalarValue::Uint(i), + .. + } => u32::try_from(i).ok()?, + crate::ConstantInner::Scalar { + value: ScalarValue::Sint(i), + .. + } => u32::try_from(i).ok()?, + _ => return None, + }, + }, + meta, + body, + )) + }) + .unwrap_or_else(|| { + self.add_expression(Expression::Access { base, index }, meta, body) + }); + + if ExprPos::Rhs == pos { + let resolved = parser.resolve_type(self, pointer, meta)?; + if resolved.pointer_space().is_some() { + return Ok(( + Some(self.add_expression(Expression::Load { pointer }, meta, body)), + meta, + )); + } + } + + pointer + } + HirExprKind::Select { base, ref field } => { + let base = self + .lower_expect_inner(stmt, parser, base, pos.maybe_access_base(true), body)? + .0; + + parser.field_selection(self, pos, body, base, field, meta)? + } + HirExprKind::Constant(constant) if pos != ExprPos::Lhs => { + self.add_expression(Expression::Constant(constant), meta, body) + } + HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => { + let (mut left, left_meta) = + self.lower_expect_inner(stmt, parser, left, ExprPos::Rhs, body)?; + let (mut right, right_meta) = + self.lower_expect_inner(stmt, parser, right, ExprPos::Rhs, body)?; + + match op { + BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => self + .implicit_conversion(parser, &mut right, right_meta, ScalarKind::Uint, 4)?, + _ => self.binary_implicit_conversion( + parser, &mut left, left_meta, &mut right, right_meta, + )?, + } + + parser.typifier_grow(self, left, left_meta)?; + parser.typifier_grow(self, right, right_meta)?; + + let left_inner = self.typifier.get(left, &parser.module.types); + let right_inner = self.typifier.get(right, &parser.module.types); + + match (left_inner, right_inner) { + ( + &TypeInner::Matrix { + columns: left_columns, + rows: left_rows, + width: left_width, + }, + &TypeInner::Matrix { + columns: right_columns, + rows: right_rows, + width: right_width, + }, + ) => { + let dimensions_ok = if op == BinaryOperator::Multiply { + left_columns == right_rows + } else { + left_columns == right_columns && left_rows == right_rows + }; + + // Check that the two arguments have the same dimensions + if !dimensions_ok || left_width != right_width { + parser.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {:?} and {:?}", + left_inner, right_inner + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Divide => { + // Naga IR doesn't support matrix division so we need to + // divide the columns individually and reassemble the matrix + let mut components = Vec::with_capacity(left_columns as usize); + + for index in 0..left_columns as u32 { + // Get the column vectors + let left_vector = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + body, + ); + let right_vector = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + body, + ); + + // Divide the vectors + let column = self.expressions.append( + Expression::Binary { + op, + left: left_vector, + right: right_vector, + }, + meta, + ); + + components.push(column) + } + + // Rebuild the matrix from the divided vectors + self.expressions.append( + Expression::Compose { + ty: parser.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns: left_columns, + rows: left_rows, + width: left_width, + }, + }, + Span::default(), + ), + components, + }, + meta, + ) + } + BinaryOperator::Equal | BinaryOperator::NotEqual => { + // Naga IR doesn't support matrix comparisons so we need to + // compare the columns individually and then fold them together + // + // The folding is done using a logical and for equality and + // a logical or for inequality + let equals = op == BinaryOperator::Equal; + + let (op, combine, fun) = match equals { + true => ( + BinaryOperator::Equal, + BinaryOperator::LogicalAnd, + RelationalFunction::All, + ), + false => ( + BinaryOperator::NotEqual, + BinaryOperator::LogicalOr, + RelationalFunction::Any, + ), + }; + + let mut root = None; + + for index in 0..left_columns as u32 { + // Get the column vectors + let left_vector = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + body, + ); + let right_vector = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + body, + ); + + let argument = self.expressions.append( + Expression::Binary { + op, + left: left_vector, + right: right_vector, + }, + meta, + ); + + // The result of comparing two vectors is a boolean vector + // so use a relational function like all to get a single + // boolean value + let compare = self.add_expression( + Expression::Relational { fun, argument }, + meta, + body, + ); + + // Fold the result + root = Some(match root { + Some(right) => self.add_expression( + Expression::Binary { + op: combine, + left: compare, + right, + }, + meta, + body, + ), + None => compare, + }); + } + + root.unwrap() + } + _ => self.add_expression( + Expression::Binary { left, op, right }, + meta, + body, + ), + } + } + (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op { + BinaryOperator::Equal | BinaryOperator::NotEqual => { + let equals = op == BinaryOperator::Equal; + + let (op, fun) = match equals { + true => (BinaryOperator::Equal, RelationalFunction::All), + false => (BinaryOperator::NotEqual, RelationalFunction::Any), + }; + + let argument = self + .expressions + .append(Expression::Binary { op, left, right }, meta); + + self.add_expression( + Expression::Relational { fun, argument }, + meta, + body, + ) + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta, body) + } + }, + (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op { + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Divide + | BinaryOperator::And + | BinaryOperator::ExclusiveOr + | BinaryOperator::InclusiveOr + | BinaryOperator::ShiftLeft + | BinaryOperator::ShiftRight => { + let scalar_vector = self.add_expression( + Expression::Splat { size, value: right }, + meta, + body, + ); + + self.add_expression( + Expression::Binary { + op, + left, + right: scalar_vector, + }, + meta, + body, + ) + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta, body) + } + }, + (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op { + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Divide + | BinaryOperator::And + | BinaryOperator::ExclusiveOr + | BinaryOperator::InclusiveOr => { + let scalar_vector = self.add_expression( + Expression::Splat { size, value: left }, + meta, + body, + ); + + self.add_expression( + Expression::Binary { + op, + left: scalar_vector, + right, + }, + meta, + body, + ) + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta, body) + } + }, + ( + &TypeInner::Scalar { + width: left_width, .. + }, + &TypeInner::Matrix { + rows, + columns, + width: right_width, + }, + ) => { + // Check that the two arguments have the same width + if left_width != right_width { + parser.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {:?} and {:?}", + left_inner, right_inner + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Divide + | BinaryOperator::Add + | BinaryOperator::Subtract => { + // Naga IR doesn't support all matrix by scalar operations so + // we need for some to turn the scalar into a vector by + // splatting it and then for each column vector apply the + // operation and finally reconstruct the matrix + let scalar_vector = self.add_expression( + Expression::Splat { + size: rows, + value: left, + }, + meta, + body, + ); + + let mut components = Vec::with_capacity(columns as usize); + + for index in 0..columns as u32 { + // Get the column vector + let matrix_column = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + body, + ); + + // Apply the operation to the splatted vector and + // the column vector + let column = self.expressions.append( + Expression::Binary { + op, + left: scalar_vector, + right: matrix_column, + }, + meta, + ); + + components.push(column) + } + + // Rebuild the matrix from the operation result vectors + self.expressions.append( + Expression::Compose { + ty: parser.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + width: left_width, + }, + }, + Span::default(), + ), + components, + }, + meta, + ) + } + _ => self.add_expression( + Expression::Binary { left, op, right }, + meta, + body, + ), + } + } + ( + &TypeInner::Matrix { + rows, + columns, + width: left_width, + }, + &TypeInner::Scalar { + width: right_width, .. + }, + ) => { + // Check that the two arguments have the same width + if left_width != right_width { + parser.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {:?} and {:?}", + left_inner, right_inner + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Divide + | BinaryOperator::Add + | BinaryOperator::Subtract => { + // Naga IR doesn't support all matrix by scalar operations so + // we need for some to turn the scalar into a vector by + // splatting it and then for each column vector apply the + // operation and finally reconstruct the matrix + + let scalar_vector = self.add_expression( + Expression::Splat { + size: rows, + value: right, + }, + meta, + body, + ); + + let mut components = Vec::with_capacity(columns as usize); + + for index in 0..columns as u32 { + // Get the column vector + let matrix_column = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + body, + ); + + // Apply the operation to the splatted vector and + // the column vector + let column = self.expressions.append( + Expression::Binary { + op, + left: matrix_column, + right: scalar_vector, + }, + meta, + ); + + components.push(column) + } + + // Rebuild the matrix from the operation result vectors + self.expressions.append( + Expression::Compose { + ty: parser.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + width: left_width, + }, + }, + Span::default(), + ), + components, + }, + meta, + ) + } + _ => self.add_expression( + Expression::Binary { left, op, right }, + meta, + body, + ), + } + } + _ => self.add_expression(Expression::Binary { left, op, right }, meta, body), + } + } + HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => { + let expr = self + .lower_expect_inner(stmt, parser, expr, ExprPos::Rhs, body)? + .0; + + self.add_expression(Expression::Unary { op, expr }, meta, body) + } + HirExprKind::Variable(ref var) => match pos { + ExprPos::Lhs => { + if !var.mutable { + parser.errors.push(Error { + kind: ErrorKind::SemanticError( + "Variable cannot be used in LHS position".into(), + ), + meta, + }) + } + + var.expr + } + ExprPos::AccessBase { constant_index } => { + // If the index isn't constant all accesses backed by a constant base need + // to be done trough a proxy local variable, since constants have a non + // pointer type which is required for dynamic indexing + if !constant_index { + if let Some((constant, ty)) = var.constant { + let local = self.locals.append( + LocalVariable { + name: None, + ty, + init: Some(constant), + }, + Span::default(), + ); + + self.add_expression( + Expression::LocalVariable(local), + Span::default(), + body, + ) + } else { + var.expr + } + } else { + var.expr + } + } + _ if var.load => { + self.add_expression(Expression::Load { pointer: var.expr }, meta, body) + } + _ => var.expr, + }, + HirExprKind::Call(ref call) if pos != ExprPos::Lhs => { + let maybe_expr = parser.function_or_constructor_call( + self, + stmt, + body, + call.kind.clone(), + &call.args, + meta, + )?; + return Ok((maybe_expr, meta)); + } + // `HirExprKind::Conditional` represents the ternary operator in glsl (`:?`) + // + // The ternary operator is defined to only evaluate one of the two possible + // expressions which means that it's behavior is that of an `if` statement, + // and it's merely syntatic sugar for it. + HirExprKind::Conditional { + condition, + accept, + reject, + } if ExprPos::Lhs != pos => { + // Given an expression `a ? b : c`, we need to produce a Naga + // statement roughly like: + // + // var temp; + // if a { + // temp = convert(b); + // } else { + // temp = convert(c); + // } + // + // where `convert` stands for type conversions to bring `b` and `c` to + // the same type, and then use `temp` to represent the value of the whole + // conditional expression in subsequent code. + + // Lower the condition first to the current bodyy + let condition = self + .lower_expect_inner(stmt, parser, condition, ExprPos::Rhs, body)? + .0; + + // Emit all expressions since we will be adding statements to + // other bodies next + self.emit_restart(body); + + // Create the bodies for the two cases + let mut accept_body = Block::new(); + let mut reject_body = Block::new(); + + // Lower the `true` branch + let (mut accept, accept_meta) = + self.lower_expect_inner(stmt, parser, accept, pos, &mut accept_body)?; + + // Flush the body of the `true` branch, to start emitting on the + // `false` branch + self.emit_restart(&mut accept_body); + + // Lower the `false` branch + let (mut reject, reject_meta) = + self.lower_expect_inner(stmt, parser, reject, pos, &mut reject_body)?; + + // Flush the body of the `false` branch + self.emit_restart(&mut reject_body); + + // We need to do some custom implicit conversions since the two target expressions + // are in different bodies + if let ( + Some((accept_power, accept_width, accept_kind)), + Some((reject_power, reject_width, reject_kind)), + ) = ( + // Get the components of both branches and calculate the type power + self.expr_scalar_components(parser, accept, accept_meta)? + .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))), + self.expr_scalar_components(parser, reject, reject_meta)? + .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))), + ) { + match accept_power.cmp(&reject_power) { + std::cmp::Ordering::Less => { + self.conversion(&mut accept, accept_meta, reject_kind, reject_width)?; + // The expression belongs to the `true` branch so we need to flush to + // the respective body + self.emit_end(&mut accept_body); + } + // Technically there's nothing to flush but later we will need to + // add some expressions that must not be emitted so instead + // of flushing, starting and flushing again, just make sure + // everything is flushed. + std::cmp::Ordering::Equal => self.emit_end(body), + std::cmp::Ordering::Greater => { + self.conversion(&mut reject, reject_meta, accept_kind, accept_width)?; + // The expression belongs to the `false` branch so we need to flush to + // the respective body + self.emit_end(&mut reject_body); + } + } + } else { + // Technically there's nothing to flush but later we will need to + // add some expressions that must not be emitted. + self.emit_end(body) + } + + // We need to get the type of the resulting expression to create the local, + // this must be done after implicit conversions to ensure both branches have + // the same type. + let ty = parser.resolve_type_handle(self, accept, accept_meta)?; + + // Add the local that will hold the result of our conditional + let local = self.locals.append( + LocalVariable { + name: None, + ty, + init: None, + }, + meta, + ); + + // Note: `Expression::LocalVariable` must not be emited so it's important + // that at this point the emitter is flushed but not started. + let local_expr = self + .expressions + .append(Expression::LocalVariable(local), meta); + + // Add to each body the store to the result variable + accept_body.push( + Statement::Store { + pointer: local_expr, + value: accept, + }, + accept_meta, + ); + reject_body.push( + Statement::Store { + pointer: local_expr, + value: reject, + }, + reject_meta, + ); + + // Finally add the `If` to the main body with the `condition` we lowered + // earlier and the branches we prepared. + body.push( + Statement::If { + condition, + accept: accept_body, + reject: reject_body, + }, + meta, + ); + + // Restart the emitter + self.emit_start(); + + // Note: `Expression::Load` must be emited before it's used so make + // sure the emitter is active here. + self.expressions.append( + Expression::Load { + pointer: local_expr, + }, + meta, + ) + } + HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => { + let (pointer, ptr_meta) = + self.lower_expect_inner(stmt, parser, tgt, ExprPos::Lhs, body)?; + let (mut value, value_meta) = + self.lower_expect_inner(stmt, parser, value, ExprPos::Rhs, body)?; + + let ty = match *parser.resolve_type(self, pointer, ptr_meta)? { + TypeInner::Pointer { base, .. } => &parser.module.types[base].inner, + ref ty => ty, + }; + + if let Some((kind, width)) = scalar_components(ty) { + self.implicit_conversion(parser, &mut value, value_meta, kind, width)?; + } + + self.lower_store(pointer, value, meta, body); + + value + } + HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => { + let (pointer, _) = + self.lower_expect_inner(stmt, parser, expr, ExprPos::Lhs, body)?; + let left = if let Expression::Swizzle { .. } = self.expressions[pointer] { + pointer + } else { + self.add_expression(Expression::Load { pointer }, meta, body) + }; + + let make_constant_inner = |kind, width| { + let value = match kind { + ScalarKind::Sint => crate::ScalarValue::Sint(1), + ScalarKind::Uint => crate::ScalarValue::Uint(1), + ScalarKind::Float => crate::ScalarValue::Float(1.0), + ScalarKind::Bool => return None, + }; + + Some(crate::ConstantInner::Scalar { width, value }) + }; + let res = match *parser.resolve_type(self, left, meta)? { + TypeInner::Scalar { kind, width } => { + let ty = TypeInner::Scalar { kind, width }; + make_constant_inner(kind, width).map(|i| (ty, i, None, None)) + } + TypeInner::Vector { size, kind, width } => { + let ty = TypeInner::Vector { size, kind, width }; + make_constant_inner(kind, width).map(|i| (ty, i, Some(size), None)) + } + TypeInner::Matrix { + columns, + rows, + width, + } => { + let ty = TypeInner::Matrix { + columns, + rows, + width, + }; + make_constant_inner(ScalarKind::Float, width) + .map(|i| (ty, i, Some(rows), Some(columns))) + } + _ => None, + }; + let (ty_inner, inner, rows, columns) = match res { + Some(res) => res, + None => { + parser.errors.push(Error { + kind: ErrorKind::SemanticError( + "Increment/decrement only works on scalar/vector/matrix".into(), + ), + meta, + }); + return Ok((Some(left), meta)); + } + }; + + let constant_1 = parser.module.constants.append( + Constant { + name: None, + specialization: None, + inner, + }, + Default::default(), + ); + let mut right = self.add_expression(Expression::Constant(constant_1), meta, body); + + // Glsl allows pre/postfixes operations on vectors and matrices, so if the + // target is either of them change the right side of the addition to be splatted + // to the same size as the target, furthermore if the target is a matrix + // use a composed matrix using the splatted value. + if let Some(size) = rows { + right = + self.add_expression(Expression::Splat { size, value: right }, meta, body); + + if let Some(cols) = columns { + let ty = parser.module.types.insert( + Type { + name: None, + inner: ty_inner, + }, + meta, + ); + + right = self.add_expression( + Expression::Compose { + ty, + components: std::iter::repeat(right).take(cols as usize).collect(), + }, + meta, + body, + ); + } + } + + let value = self.add_expression(Expression::Binary { op, left, right }, meta, body); + + self.lower_store(pointer, value, meta, body); + + if postfix { + left + } else { + value + } + } + HirExprKind::Method { + expr: object, + ref name, + ref args, + } if ExprPos::Lhs != pos => { + let args = args + .iter() + .map(|e| self.lower_expect_inner(stmt, parser, *e, ExprPos::Rhs, body)) + .collect::<Result<Vec<_>>>()?; + match name.as_ref() { + "length" => { + if !args.is_empty() { + parser.errors.push(Error { + kind: ErrorKind::SemanticError( + ".length() doesn't take any arguments".into(), + ), + meta, + }); + } + let lowered_array = + self.lower_expect_inner(stmt, parser, object, pos, body)?.0; + let array_type = parser.resolve_type(self, lowered_array, meta)?; + + match *array_type { + TypeInner::Array { + size: crate::ArraySize::Constant(size), + .. + } => { + let mut array_length = + self.add_expression(Expression::Constant(size), meta, body); + self.forced_conversion( + parser, + &mut array_length, + meta, + ScalarKind::Sint, + 4, + )?; + array_length + } + // let the error be handled in type checking if it's not a dynamic array + _ => { + let mut array_length = self.add_expression( + Expression::ArrayLength(lowered_array), + meta, + body, + ); + self.conversion(&mut array_length, meta, ScalarKind::Sint, 4)?; + array_length + } + } + } + _ => { + return Err(Error { + kind: ErrorKind::SemanticError( + format!("unknown method '{}'", name).into(), + ), + meta, + }); + } + } + } + _ => { + return Err(Error { + kind: ErrorKind::SemanticError( + format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr]) + .into(), + ), + meta, + }) + } + }; + + log::trace!( + "Lowered {:?}\n\tKind = {:?}\n\tPos = {:?}\n\tResult = {:?}", + expr, + kind, + pos, + handle + ); + + Ok((Some(handle), meta)) + } + + pub fn expr_scalar_components( + &mut self, + parser: &Parser, + expr: Handle<Expression>, + meta: Span, + ) -> Result<Option<(ScalarKind, crate::Bytes)>> { + let ty = parser.resolve_type(self, expr, meta)?; + Ok(scalar_components(ty)) + } + + pub fn expr_power( + &mut self, + parser: &Parser, + expr: Handle<Expression>, + meta: Span, + ) -> Result<Option<u32>> { + Ok(self + .expr_scalar_components(parser, expr, meta)? + .and_then(|(kind, width)| type_power(kind, width))) + } + + pub fn conversion( + &mut self, + expr: &mut Handle<Expression>, + meta: Span, + kind: ScalarKind, + width: crate::Bytes, + ) -> Result<()> { + *expr = self.expressions.append( + Expression::As { + expr: *expr, + kind, + convert: Some(width), + }, + meta, + ); + + Ok(()) + } + + pub fn implicit_conversion( + &mut self, + parser: &Parser, + expr: &mut Handle<Expression>, + meta: Span, + kind: ScalarKind, + width: crate::Bytes, + ) -> Result<()> { + if let (Some(tgt_power), Some(expr_power)) = ( + type_power(kind, width), + self.expr_power(parser, *expr, meta)?, + ) { + if tgt_power > expr_power { + self.conversion(expr, meta, kind, width)?; + } + } + + Ok(()) + } + + pub fn forced_conversion( + &mut self, + parser: &Parser, + expr: &mut Handle<Expression>, + meta: Span, + kind: ScalarKind, + width: crate::Bytes, + ) -> Result<()> { + if let Some((expr_scalar_kind, expr_width)) = + self.expr_scalar_components(parser, *expr, meta)? + { + if expr_scalar_kind != kind || expr_width != width { + self.conversion(expr, meta, kind, width)?; + } + } + + Ok(()) + } + + pub fn binary_implicit_conversion( + &mut self, + parser: &Parser, + left: &mut Handle<Expression>, + left_meta: Span, + right: &mut Handle<Expression>, + right_meta: Span, + ) -> Result<()> { + let left_components = self.expr_scalar_components(parser, *left, left_meta)?; + let right_components = self.expr_scalar_components(parser, *right, right_meta)?; + + if let ( + Some((left_power, left_width, left_kind)), + Some((right_power, right_width, right_kind)), + ) = ( + left_components.and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))), + right_components + .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))), + ) { + match left_power.cmp(&right_power) { + std::cmp::Ordering::Less => { + self.conversion(left, left_meta, right_kind, right_width)?; + } + std::cmp::Ordering::Equal => {} + std::cmp::Ordering::Greater => { + self.conversion(right, right_meta, left_kind, left_width)?; + } + } + } + + Ok(()) + } + + pub fn implicit_splat( + &mut self, + parser: &Parser, + expr: &mut Handle<Expression>, + meta: Span, + vector_size: Option<VectorSize>, + ) -> Result<()> { + let expr_type = parser.resolve_type(self, *expr, meta)?; + + if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) { + *expr = self + .expressions + .append(Expression::Splat { size, value: *expr }, meta) + } + + Ok(()) + } + + pub fn vector_resize( + &mut self, + size: VectorSize, + vector: Handle<Expression>, + meta: Span, + body: &mut Block, + ) -> Handle<Expression> { + self.add_expression( + Expression::Swizzle { + size, + vector, + pattern: crate::SwizzleComponent::XYZW, + }, + meta, + body, + ) + } +} + +impl Index<Handle<Expression>> for Context { + type Output = Expression; + + fn index(&self, index: Handle<Expression>) -> &Self::Output { + &self.expressions[index] + } +} + +/// Helper struct passed when parsing expressions +/// +/// This struct should only be obtained trough [`stmt_ctx`](Context::stmt_ctx) +/// and only one of these may be active at any time per context. +#[derive(Debug)] +pub struct StmtContext { + /// A arena of high level expressions which can be lowered trough a + /// [`Context`](Context) to naga's [`Expression`](crate::Expression)s + pub hir_exprs: Arena<HirExpr>, +} + +impl StmtContext { + const fn new() -> Self { + StmtContext { + hir_exprs: Arena::new(), + } + } +} |