From 8dd16259287f58f9273002717ec4d27e97127719 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 12 Jun 2024 07:43:14 +0200 Subject: Merging upstream version 127.0. Signed-off-by: Daniel Baumann --- .../rust/naga/src/proc/constant_evaluator.rs | 589 ++++++++++++++------- 1 file changed, 392 insertions(+), 197 deletions(-) (limited to 'third_party/rust/naga/src/proc/constant_evaluator.rs') diff --git a/third_party/rust/naga/src/proc/constant_evaluator.rs b/third_party/rust/naga/src/proc/constant_evaluator.rs index 983af3718c..ead3d00980 100644 --- a/third_party/rust/naga/src/proc/constant_evaluator.rs +++ b/third_party/rust/naga/src/proc/constant_evaluator.rs @@ -4,8 +4,8 @@ use arrayvec::ArrayVec; use crate::{ arena::{Arena, Handle, UniqueArena}, - ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner, - UnaryOperator, + ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type, + TypeInner, UnaryOperator, }; /// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating @@ -253,9 +253,20 @@ gen_component_wise_extractor! { } #[derive(Debug)] -enum Behavior { - Wgsl, - Glsl, +enum Behavior<'a> { + Wgsl(WgslRestrictions<'a>), + Glsl(GlslRestrictions<'a>), +} + +impl Behavior<'_> { + /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions. + const fn has_runtime_restrictions(&self) -> bool { + matches!( + self, + &Behavior::Wgsl(WgslRestrictions::Runtime(_)) + | &Behavior::Glsl(GlslRestrictions::Runtime(_)) + ) + } } /// A context for evaluating constant expressions. @@ -278,7 +289,7 @@ enum Behavior { #[derive(Debug)] pub struct ConstantEvaluator<'a> { /// Which language's evaluation rules we should follow. - behavior: Behavior, + behavior: Behavior<'a>, /// The module's type arena. /// @@ -291,71 +302,155 @@ pub struct ConstantEvaluator<'a> { /// The module's constant arena. constants: &'a Arena, + /// The module's override arena. + overrides: &'a Arena, + /// The arena to which we are contributing expressions. expressions: &'a mut Arena, - /// When `self.expressions` refers to a function's local expression - /// arena, this needs to be populated - function_local_data: Option>, + /// Tracks the constness of expressions residing in [`Self::expressions`] + expression_kind_tracker: &'a mut ExpressionKindTracker, +} + +#[derive(Debug)] +enum WgslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + Override, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), +} + +#[derive(Debug)] +enum GlslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), } #[derive(Debug)] struct FunctionLocalData<'a> { /// Global constant expressions - const_expressions: &'a Arena, - /// Tracks the constness of expressions residing in `ConstantEvaluator.expressions` - expression_constness: &'a mut ExpressionConstnessTracker, + global_expressions: &'a Arena, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum ExpressionKind { + Const, + Override, + Runtime, +} + #[derive(Debug)] -pub struct ExpressionConstnessTracker { - inner: bit_set::BitSet, +pub struct ExpressionKindTracker { + inner: Vec, } -impl ExpressionConstnessTracker { - pub fn new() -> Self { - Self { - inner: bit_set::BitSet::new(), - } +impl ExpressionKindTracker { + pub const fn new() -> Self { + Self { inner: Vec::new() } } /// Forces the the expression to not be const pub fn force_non_const(&mut self, value: Handle) { - self.inner.remove(value.index()); + self.inner[value.index()] = ExpressionKind::Runtime; } - fn insert(&mut self, value: Handle) { - self.inner.insert(value.index()); + pub fn insert(&mut self, value: Handle, expr_type: ExpressionKind) { + assert_eq!(self.inner.len(), value.index()); + self.inner.push(expr_type); + } + pub fn is_const(&self, h: Handle) -> bool { + matches!(self.type_of(h), ExpressionKind::Const) + } + + pub fn is_const_or_override(&self, h: Handle) -> bool { + matches!( + self.type_of(h), + ExpressionKind::Const | ExpressionKind::Override + ) } - pub fn is_const(&self, value: Handle) -> bool { - self.inner.contains(value.index()) + fn type_of(&self, value: Handle) -> ExpressionKind { + self.inner[value.index()] } pub fn from_arena(arena: &Arena) -> Self { - let mut tracker = Self::new(); - for (handle, expr) in arena.iter() { - let insert = match *expr { - crate::Expression::Literal(_) - | crate::Expression::ZeroValue(_) - | crate::Expression::Constant(_) => true, - crate::Expression::Compose { ref components, .. } => { - components.iter().all(|h| tracker.is_const(*h)) - } - crate::Expression::Splat { value, .. } => tracker.is_const(value), - _ => false, - }; - if insert { - tracker.insert(handle); - } + let mut tracker = Self { + inner: Vec::with_capacity(arena.len()), + }; + for (_, expr) in arena.iter() { + tracker.inner.push(tracker.type_of_with_expr(expr)); } tracker } + + fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind { + match *expr { + Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { + ExpressionKind::Const + } + Expression::Override(_) => ExpressionKind::Override, + Expression::Compose { ref components, .. } => { + let mut expr_type = ExpressionKind::Const; + for component in components { + expr_type = expr_type.max(self.type_of(*component)) + } + expr_type + } + Expression::Splat { value, .. } => self.type_of(value), + Expression::AccessIndex { base, .. } => self.type_of(base), + Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)), + Expression::Swizzle { vector, .. } => self.type_of(vector), + Expression::Unary { expr, .. } => self.type_of(expr), + Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)), + Expression::Math { + arg, + arg1, + arg2, + arg3, + .. + } => self + .type_of(arg) + .max( + arg1.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg2.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg3.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ), + Expression::As { expr, .. } => self.type_of(expr), + Expression::Select { + condition, + accept, + reject, + } => self + .type_of(condition) + .max(self.type_of(accept)) + .max(self.type_of(reject)), + Expression::Relational { argument, .. } => self.type_of(argument), + Expression::ArrayLength(expr) => self.type_of(expr), + _ => ExpressionKind::Runtime, + } + } } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantEvaluatorError { #[error("Constants cannot access function arguments")] FunctionArg, @@ -381,6 +476,8 @@ pub enum ConstantEvaluatorError { ImageExpression, #[error("Constants don't support ray query expressions")] RayQueryExpression, + #[error("Constants don't support subgroup expressions")] + SubgroupExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -432,6 +529,12 @@ pub enum ConstantEvaluatorError { ShiftedMoreThan32Bits, #[error(transparent)] Literal(#[from] crate::valid::LiteralError), + #[error("Can't use pipeline-overridable constants in const-expressions")] + Override, + #[error("Unexpected runtime-expression")] + RuntimeExpr, + #[error("Unexpected override-expression")] + OverrideExpr, } impl<'a> ConstantEvaluator<'a> { @@ -439,25 +542,49 @@ impl<'a> ConstantEvaluator<'a> { /// constant expression arena. /// /// Report errors according to WGSL's rules for constant evaluation. - pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Wgsl, module) + pub fn for_wgsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, + in_override_ctx: bool, + ) -> Self { + Self::for_module( + Behavior::Wgsl(if in_override_ctx { + WgslRestrictions::Override + } else { + WgslRestrictions::Const + }), + module, + global_expression_kind_tracker, + ) } /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s /// constant expression arena. /// /// Report errors according to GLSL's rules for constant evaluation. - pub fn for_glsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Glsl, module) + pub fn for_glsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, + ) -> Self { + Self::for_module( + Behavior::Glsl(GlslRestrictions::Const), + module, + global_expression_kind_tracker, + ) } - fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self { + fn for_module( + behavior: Behavior<'a>, + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, + ) -> Self { Self { behavior, types: &mut module.types, constants: &module.constants, - expressions: &mut module.const_expressions, - function_local_data: None, + overrides: &module.overrides, + expressions: &mut module.global_expressions, + expression_kind_tracker: global_expression_kind_tracker, } } @@ -468,18 +595,22 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_wgsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionKindTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { - Self::for_function( - Behavior::Wgsl, - module, + Self { + behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData { + global_expressions: &module.global_expressions, + emitter, + block, + })), + types: &mut module.types, + constants: &module.constants, + overrides: &module.overrides, expressions, - expression_constness, - emitter, - block, - ) + expression_kind_tracker: local_expression_kind_tracker, + } } /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s @@ -489,39 +620,21 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_glsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena, - expression_constness: &'a mut ExpressionConstnessTracker, - emitter: &'a mut super::Emitter, - block: &'a mut crate::Block, - ) -> Self { - Self::for_function( - Behavior::Glsl, - module, - expressions, - expression_constness, - emitter, - block, - ) - } - - fn for_function( - behavior: Behavior, - module: &'a mut crate::Module, - expressions: &'a mut Arena, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionKindTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { Self { - behavior, + behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData { + global_expressions: &module.global_expressions, + emitter, + block, + })), types: &mut module.types, constants: &module.constants, + overrides: &module.overrides, expressions, - function_local_data: Some(FunctionLocalData { - const_expressions: &module.const_expressions, - expression_constness, - emitter, - block, - }), + expression_kind_tracker: local_expression_kind_tracker, } } @@ -529,19 +642,18 @@ impl<'a> ConstantEvaluator<'a> { crate::proc::GlobalCtx { types: self.types, constants: self.constants, - const_expressions: match self.function_local_data { - Some(ref data) => data.const_expressions, + overrides: self.overrides, + global_expressions: match self.function_local_data() { + Some(data) => data.global_expressions, None => self.expressions, }, } } fn check(&self, expr: Handle) -> Result<(), ConstantEvaluatorError> { - if let Some(ref function_local_data) = self.function_local_data { - if !function_local_data.expression_constness.is_const(expr) { - log::debug!("check: SubexpressionsAreNotConstant"); - return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); - } + if !self.expression_kind_tracker.is_const(expr) { + log::debug!("check: SubexpressionsAreNotConstant"); + return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); } Ok(()) } @@ -554,11 +666,11 @@ impl<'a> ConstantEvaluator<'a> { Expression::Constant(c) => { // Are we working in a function's expression arena, or the // module's constant expression arena? - if let Some(ref function_local_data) = self.function_local_data { + if let Some(function_local_data) = self.function_local_data() { // Deep-copy the constant's value into our arena. self.copy_from( self.constants[c].init, - function_local_data.const_expressions, + function_local_data.global_expressions, ) } else { // "See through" the constant and use its initializer. @@ -580,9 +692,11 @@ impl<'a> ConstantEvaluator<'a> { /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena /// `self` contributes to. /// - /// If `expr`'s value cannot be determined at compile time, return a an - /// error. If it's acceptable to evaluate `expr` at runtime, this error can - /// be ignored, and the caller can append `expr` to the arena itself. + /// If `expr`'s value cannot be determined at compile time, and `self` is + /// contributing to some function's expression arena, then append `expr` to + /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be + /// contributing to the module's constant expression arena; since `expr`'s + /// value is not a constant, return an error. /// /// We only consider `expr` itself, without recursing into its operands. Its /// operands must all have been produced by prior calls to @@ -594,17 +708,82 @@ impl<'a> ConstantEvaluator<'a> { /// [`ZeroValue`]: Expression::ZeroValue /// [`Swizzle`]: Expression::Swizzle pub fn try_eval_and_append( + &mut self, + expr: Expression, + span: Span, + ) -> Result, ConstantEvaluatorError> { + match self.expression_kind_tracker.type_of_with_expr(&expr) { + ExpressionKind::Const => { + let eval_result = self.try_eval_and_append_impl(&expr, span); + // We should be able to evaluate `Const` expressions at this + // point. If we failed to, then that probably means we just + // haven't implemented that part of constant evaluation. Work + // around this by simply emitting it as a run-time expression. + if self.behavior.has_runtime_restrictions() + && matches!( + eval_result, + Err(ConstantEvaluatorError::NotImplemented(_) + | ConstantEvaluatorError::InvalidBinaryOpArgs,) + ) + { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + eval_result + } + } + ExpressionKind::Override => match self.behavior { + Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => { + Ok(self.append_expr(expr, span, ExpressionKind::Override)) + } + Behavior::Wgsl(WgslRestrictions::Const) => { + Err(ConstantEvaluatorError::OverrideExpr) + } + Behavior::Glsl(_) => { + unreachable!() + } + }, + ExpressionKind::Runtime => { + if self.behavior.has_runtime_restrictions() { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + Err(ConstantEvaluatorError::RuntimeExpr) + } + } + } + } + + /// Is the [`Self::expressions`] arena the global module expression arena? + const fn is_global_arena(&self) -> bool { + matches!( + self.behavior, + Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override) + | Behavior::Glsl(GlslRestrictions::Const) + ) + } + + const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> { + match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => { + Some(function_local_data) + } + _ => None, + } + } + + fn try_eval_and_append_impl( &mut self, expr: &Expression, span: Span, ) -> Result, ConstantEvaluatorError> { log::trace!("try_eval_and_append: {:?}", expr); match *expr { - Expression::Constant(c) if self.function_local_data.is_none() => { + Expression::Constant(c) if self.is_global_arena() => { // "See through" the constant and use its initializer. // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } + Expression::Override(_) => Err(ConstantEvaluatorError::Override), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } @@ -685,8 +864,8 @@ impl<'a> ConstantEvaluator<'a> { format!("{fun:?} built-in function"), )), Expression::ArrayLength(expr) => match self.behavior { - Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength), - Behavior::Glsl => { + Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength), + Behavior::Glsl(_) => { let expr = self.check_and_get(expr)?; self.array_length(expr, span) } @@ -707,6 +886,12 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } + Expression::SubgroupBallotResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } + Expression::SubgroupOperationResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } } } @@ -765,10 +950,10 @@ impl<'a> ConstantEvaluator<'a> { pattern: [crate::SwizzleComponent; 4], ) -> Result, ConstantEvaluatorError> { let mut get_dst_ty = |ty| match self.types[ty].inner { - crate::TypeInner::Vector { size: _, scalar } => Ok(self.types.insert( + TypeInner::Vector { size: _, scalar } => Ok(self.types.insert( Type { name: None, - inner: crate::TypeInner::Vector { size, scalar }, + inner: TypeInner::Vector { size, scalar }, }, span, )), @@ -1059,13 +1244,11 @@ impl<'a> ConstantEvaluator<'a> { Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => { match self.types[ty].inner { TypeInner::Array { size, .. } => match size { - crate::ArraySize::Constant(len) => { + ArraySize::Constant(len) => { let expr = Expression::Literal(Literal::U32(len.get())); self.register_evaluated_expr(expr, span) } - crate::ArraySize::Dynamic => { - Err(ConstantEvaluatorError::ArrayLengthDynamic) - } + ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic), }, _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg), } @@ -1128,7 +1311,7 @@ impl<'a> ConstantEvaluator<'a> { Expression::ZeroValue(ty) if matches!( self.types[ty].inner, - crate::TypeInner::Scalar(crate::Scalar { + TypeInner::Scalar(crate::Scalar { kind: ScalarKind::Uint, .. }) @@ -1443,7 +1626,7 @@ impl<'a> ConstantEvaluator<'a> { return self.cast(expr, target, span); }; - let crate::TypeInner::Array { + let TypeInner::Array { base: _, size, stride: _, @@ -1853,29 +2036,35 @@ impl<'a> ConstantEvaluator<'a> { crate::valid::check_literal_value(literal)?; } - if let Some(FunctionLocalData { - ref mut emitter, - ref mut block, - ref mut expression_constness, - .. - }) = self.function_local_data - { - let is_running = emitter.is_running(); - let needs_pre_emit = expr.needs_pre_emit(); - if is_running && needs_pre_emit { - block.extend(emitter.finish(self.expressions)); - let h = self.expressions.append(expr, span); - emitter.start(self.expressions); - expression_constness.insert(h); - Ok(h) - } else { - let h = self.expressions.append(expr, span); - expression_constness.insert(h); - Ok(h) + Ok(self.append_expr(expr, span, ExpressionKind::Const)) + } + + fn append_expr( + &mut self, + expr: Expression, + span: Span, + expr_type: ExpressionKind, + ) -> Handle { + let h = match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => { + let is_running = function_local_data.emitter.is_running(); + let needs_pre_emit = expr.needs_pre_emit(); + if is_running && needs_pre_emit { + function_local_data + .block + .extend(function_local_data.emitter.finish(self.expressions)); + let h = self.expressions.append(expr, span); + function_local_data.emitter.start(self.expressions); + h + } else { + self.expressions.append(expr, span) + } } - } else { - Ok(self.expressions.append(expr, span)) - } + _ => self.expressions.append(expr, span), + }; + self.expression_kind_tracker.insert(h, expr_type); + h } fn resolve_type( @@ -2029,13 +2218,14 @@ mod tests { UniqueArena, VectorSize, }; - use super::{Behavior, ConstantEvaluator}; + use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions}; #[test] fn unary_op() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let scalar_ty = types.insert( Type { @@ -2059,9 +2249,8 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), @@ -2070,9 +2259,8 @@ mod tests { let h1 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(8)), Default::default()), }, Default::default(), @@ -2081,9 +2269,8 @@ mod tests { let vec_h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec![constants[h].init, constants[h1].init], @@ -2094,8 +2281,8 @@ mod tests { Default::default(), ); - let expr = const_expressions.append(Expression::Constant(h), Default::default()); - let expr1 = const_expressions.append(Expression::Constant(vec_h), Default::default()); + let expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default()); let expr2 = Expression::Unary { op: UnaryOperator::Negate, @@ -2112,35 +2299,37 @@ mod tests { expr: expr1, }; + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let res1 = solver - .try_eval_and_append(&expr2, Default::default()) + .try_eval_and_append(expr2, Default::default()) .unwrap(); let res2 = solver - .try_eval_and_append(&expr3, Default::default()) + .try_eval_and_append(expr3, Default::default()) .unwrap(); let res3 = solver - .try_eval_and_append(&expr4, Default::default()) + .try_eval_and_append(expr4, Default::default()) .unwrap(); assert_eq!( - const_expressions[res1], + global_expressions[res1], Expression::Literal(Literal::I32(-4)) ); assert_eq!( - const_expressions[res2], + global_expressions[res2], Expression::Literal(Literal::I32(!4)) ); - let res3_inner = &const_expressions[res3]; + let res3_inner = &global_expressions[res3]; match *res3_inner { Expression::Compose { @@ -2150,11 +2339,11 @@ mod tests { assert_eq!(*ty, vec_ty); let mut components_iter = components.iter().copied(); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::I32(!4)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::I32(!8)) ); assert!(components_iter.next().is_none()); @@ -2167,7 +2356,8 @@ mod tests { fn cast() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let scalar_ty = types.insert( Type { @@ -2180,15 +2370,14 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let expr = const_expressions.append(Expression::Constant(h), Default::default()); + let expr = global_expressions.append(Expression::Constant(h), Default::default()); let root = Expression::As { expr, @@ -2196,20 +2385,22 @@ mod tests { convert: Some(crate::BOOL_WIDTH), }; + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let res = solver - .try_eval_and_append(&root, Default::default()) + .try_eval_and_append(root, Default::default()) .unwrap(); assert_eq!( - const_expressions[res], + global_expressions[res], Expression::Literal(Literal::Bool(true)) ); } @@ -2218,7 +2409,8 @@ mod tests { fn access() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let matrix_ty = types.insert( Type { @@ -2247,7 +2439,7 @@ mod tests { let mut vec2_components = Vec::with_capacity(3); for i in 0..3 { - let h = const_expressions.append( + let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); @@ -2256,7 +2448,7 @@ mod tests { } for i in 3..6 { - let h = const_expressions.append( + let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); @@ -2267,9 +2459,8 @@ mod tests { let vec1 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec1_components, @@ -2283,9 +2474,8 @@ mod tests { let vec2 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec2_components, @@ -2299,9 +2489,8 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: matrix_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: matrix_ty, components: vec![constants[vec1].init, constants[vec2].init], @@ -2312,20 +2501,22 @@ mod tests { Default::default(), ); - let base = const_expressions.append(Expression::Constant(h), Default::default()); + let base = global_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let root1 = Expression::AccessIndex { base, index: 1 }; let res1 = solver - .try_eval_and_append(&root1, Default::default()) + .try_eval_and_append(root1, Default::default()) .unwrap(); let root2 = Expression::AccessIndex { @@ -2334,10 +2525,10 @@ mod tests { }; let res2 = solver - .try_eval_and_append(&root2, Default::default()) + .try_eval_and_append(root2, Default::default()) .unwrap(); - match const_expressions[res1] { + match global_expressions[res1] { Expression::Compose { ref ty, ref components, @@ -2345,15 +2536,15 @@ mod tests { assert_eq!(*ty, vec_ty); let mut components_iter = components.iter().copied(); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(3.)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(4.)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(5.)) ); assert!(components_iter.next().is_none()); @@ -2362,7 +2553,7 @@ mod tests { } assert_eq!( - const_expressions[res2], + global_expressions[res2], Expression::Literal(Literal::F32(5.)) ); } @@ -2371,7 +2562,8 @@ mod tests { fn compose_of_constants() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { @@ -2395,27 +2587,28 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let solved_compose = solver .try_eval_and_append( - &Expression::Compose { + Expression::Compose { ty: vec2_i32_ty, components: vec![h_expr, h_expr], }, @@ -2424,7 +2617,7 @@ mod tests { .unwrap(); let solved_negate = solver .try_eval_and_append( - &Expression::Unary { + Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, @@ -2432,11 +2625,11 @@ mod tests { ) .unwrap(); - let pass = match const_expressions[solved_negate] { + let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { - let component = &const_expressions[component]; + let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::I32(-4))) }) } @@ -2451,7 +2644,8 @@ mod tests { fn splat_of_constant() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { @@ -2475,27 +2669,28 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let solved_compose = solver .try_eval_and_append( - &Expression::Splat { + Expression::Splat { size: VectorSize::Bi, value: h_expr, }, @@ -2504,7 +2699,7 @@ mod tests { .unwrap(); let solved_negate = solver .try_eval_and_append( - &Expression::Unary { + Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, @@ -2512,11 +2707,11 @@ mod tests { ) .unwrap(); - let pass = match const_expressions[solved_negate] { + let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { - let component = &const_expressions[component]; + let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::I32(-4))) }) } -- cgit v1.2.3