summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/proc/constant_evaluator.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/proc/constant_evaluator.rs')
-rw-r--r--third_party/rust/naga/src/proc/constant_evaluator.rs589
1 files changed, 392 insertions, 197 deletions
diff --git a/third_party/rust/naga/src/proc/constant_evaluator.rs b/third_party/rust/naga/src/proc/constant_evaluator.rs
index 983af3718c..ead3d00980 100644
--- a/third_party/rust/naga/src/proc/constant_evaluator.rs
+++ b/third_party/rust/naga/src/proc/constant_evaluator.rs
@@ -4,8 +4,8 @@ use arrayvec::ArrayVec;
use crate::{
arena::{Arena, Handle, UniqueArena},
- ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner,
- UnaryOperator,
+ ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
+ TypeInner, UnaryOperator,
};
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
@@ -253,9 +253,20 @@ gen_component_wise_extractor! {
}
#[derive(Debug)]
-enum Behavior {
- Wgsl,
- Glsl,
+enum Behavior<'a> {
+ Wgsl(WgslRestrictions<'a>),
+ Glsl(GlslRestrictions<'a>),
+}
+
+impl Behavior<'_> {
+ /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
+ const fn has_runtime_restrictions(&self) -> bool {
+ matches!(
+ self,
+ &Behavior::Wgsl(WgslRestrictions::Runtime(_))
+ | &Behavior::Glsl(GlslRestrictions::Runtime(_))
+ )
+ }
}
/// A context for evaluating constant expressions.
@@ -278,7 +289,7 @@ enum Behavior {
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
/// Which language's evaluation rules we should follow.
- behavior: Behavior,
+ behavior: Behavior<'a>,
/// The module's type arena.
///
@@ -291,71 +302,155 @@ pub struct ConstantEvaluator<'a> {
/// The module's constant arena.
constants: &'a Arena<Constant>,
+ /// The module's override arena.
+ overrides: &'a Arena<Override>,
+
/// The arena to which we are contributing expressions.
expressions: &'a mut Arena<Expression>,
- /// When `self.expressions` refers to a function's local expression
- /// arena, this needs to be populated
- function_local_data: Option<FunctionLocalData<'a>>,
+ /// Tracks the constness of expressions residing in [`Self::expressions`]
+ expression_kind_tracker: &'a mut ExpressionKindTracker,
+}
+
+#[derive(Debug)]
+enum WgslRestrictions<'a> {
+ /// - const-expressions will be evaluated and inserted in the arena
+ Const,
+ /// - const-expressions will be evaluated and inserted in the arena
+ /// - override-expressions will be inserted in the arena
+ Override,
+ /// - const-expressions will be evaluated and inserted in the arena
+ /// - override-expressions will be inserted in the arena
+ /// - runtime-expressions will be inserted in the arena
+ Runtime(FunctionLocalData<'a>),
+}
+
+#[derive(Debug)]
+enum GlslRestrictions<'a> {
+ /// - const-expressions will be evaluated and inserted in the arena
+ Const,
+ /// - const-expressions will be evaluated and inserted in the arena
+ /// - override-expressions will be inserted in the arena
+ /// - runtime-expressions will be inserted in the arena
+ Runtime(FunctionLocalData<'a>),
}
#[derive(Debug)]
struct FunctionLocalData<'a> {
/// Global constant expressions
- const_expressions: &'a Arena<Expression>,
- /// Tracks the constness of expressions residing in `ConstantEvaluator.expressions`
- expression_constness: &'a mut ExpressionConstnessTracker,
+ global_expressions: &'a Arena<Expression>,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
}
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
+pub enum ExpressionKind {
+ Const,
+ Override,
+ Runtime,
+}
+
#[derive(Debug)]
-pub struct ExpressionConstnessTracker {
- inner: bit_set::BitSet,
+pub struct ExpressionKindTracker {
+ inner: Vec<ExpressionKind>,
}
-impl ExpressionConstnessTracker {
- pub fn new() -> Self {
- Self {
- inner: bit_set::BitSet::new(),
- }
+impl ExpressionKindTracker {
+ pub const fn new() -> Self {
+ Self { inner: Vec::new() }
}
/// Forces the the expression to not be const
pub fn force_non_const(&mut self, value: Handle<Expression>) {
- self.inner.remove(value.index());
+ self.inner[value.index()] = ExpressionKind::Runtime;
}
- fn insert(&mut self, value: Handle<Expression>) {
- self.inner.insert(value.index());
+ pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
+ assert_eq!(self.inner.len(), value.index());
+ self.inner.push(expr_type);
+ }
+ pub fn is_const(&self, h: Handle<Expression>) -> bool {
+ matches!(self.type_of(h), ExpressionKind::Const)
+ }
+
+ pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
+ matches!(
+ self.type_of(h),
+ ExpressionKind::Const | ExpressionKind::Override
+ )
}
- pub fn is_const(&self, value: Handle<Expression>) -> bool {
- self.inner.contains(value.index())
+ fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
+ self.inner[value.index()]
}
pub fn from_arena(arena: &Arena<Expression>) -> Self {
- let mut tracker = Self::new();
- for (handle, expr) in arena.iter() {
- let insert = match *expr {
- crate::Expression::Literal(_)
- | crate::Expression::ZeroValue(_)
- | crate::Expression::Constant(_) => true,
- crate::Expression::Compose { ref components, .. } => {
- components.iter().all(|h| tracker.is_const(*h))
- }
- crate::Expression::Splat { value, .. } => tracker.is_const(value),
- _ => false,
- };
- if insert {
- tracker.insert(handle);
- }
+ let mut tracker = Self {
+ inner: Vec::with_capacity(arena.len()),
+ };
+ for (_, expr) in arena.iter() {
+ tracker.inner.push(tracker.type_of_with_expr(expr));
}
tracker
}
+
+ fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
+ match *expr {
+ Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
+ ExpressionKind::Const
+ }
+ Expression::Override(_) => ExpressionKind::Override,
+ Expression::Compose { ref components, .. } => {
+ let mut expr_type = ExpressionKind::Const;
+ for component in components {
+ expr_type = expr_type.max(self.type_of(*component))
+ }
+ expr_type
+ }
+ Expression::Splat { value, .. } => self.type_of(value),
+ Expression::AccessIndex { base, .. } => self.type_of(base),
+ Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
+ Expression::Swizzle { vector, .. } => self.type_of(vector),
+ Expression::Unary { expr, .. } => self.type_of(expr),
+ Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
+ Expression::Math {
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ ..
+ } => self
+ .type_of(arg)
+ .max(
+ arg1.map(|arg| self.type_of(arg))
+ .unwrap_or(ExpressionKind::Const),
+ )
+ .max(
+ arg2.map(|arg| self.type_of(arg))
+ .unwrap_or(ExpressionKind::Const),
+ )
+ .max(
+ arg3.map(|arg| self.type_of(arg))
+ .unwrap_or(ExpressionKind::Const),
+ ),
+ Expression::As { expr, .. } => self.type_of(expr),
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => self
+ .type_of(condition)
+ .max(self.type_of(accept))
+ .max(self.type_of(reject)),
+ Expression::Relational { argument, .. } => self.type_of(argument),
+ Expression::ArrayLength(expr) => self.type_of(expr),
+ _ => ExpressionKind::Runtime,
+ }
+ }
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
#[error("Constants cannot access function arguments")]
FunctionArg,
@@ -381,6 +476,8 @@ pub enum ConstantEvaluatorError {
ImageExpression,
#[error("Constants don't support ray query expressions")]
RayQueryExpression,
+ #[error("Constants don't support subgroup expressions")]
+ SubgroupExpression,
#[error("Cannot access the type")]
InvalidAccessBase,
#[error("Cannot access at the index")]
@@ -432,6 +529,12 @@ pub enum ConstantEvaluatorError {
ShiftedMoreThan32Bits,
#[error(transparent)]
Literal(#[from] crate::valid::LiteralError),
+ #[error("Can't use pipeline-overridable constants in const-expressions")]
+ Override,
+ #[error("Unexpected runtime-expression")]
+ RuntimeExpr,
+ #[error("Unexpected override-expression")]
+ OverrideExpr,
}
impl<'a> ConstantEvaluator<'a> {
@@ -439,25 +542,49 @@ impl<'a> ConstantEvaluator<'a> {
/// constant expression arena.
///
/// Report errors according to WGSL's rules for constant evaluation.
- pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self {
- Self::for_module(Behavior::Wgsl, module)
+ pub fn for_wgsl_module(
+ module: &'a mut crate::Module,
+ global_expression_kind_tracker: &'a mut ExpressionKindTracker,
+ in_override_ctx: bool,
+ ) -> Self {
+ Self::for_module(
+ Behavior::Wgsl(if in_override_ctx {
+ WgslRestrictions::Override
+ } else {
+ WgslRestrictions::Const
+ }),
+ module,
+ global_expression_kind_tracker,
+ )
}
/// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
/// constant expression arena.
///
/// Report errors according to GLSL's rules for constant evaluation.
- pub fn for_glsl_module(module: &'a mut crate::Module) -> Self {
- Self::for_module(Behavior::Glsl, module)
+ pub fn for_glsl_module(
+ module: &'a mut crate::Module,
+ global_expression_kind_tracker: &'a mut ExpressionKindTracker,
+ ) -> Self {
+ Self::for_module(
+ Behavior::Glsl(GlslRestrictions::Const),
+ module,
+ global_expression_kind_tracker,
+ )
}
- fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self {
+ fn for_module(
+ behavior: Behavior<'a>,
+ module: &'a mut crate::Module,
+ global_expression_kind_tracker: &'a mut ExpressionKindTracker,
+ ) -> Self {
Self {
behavior,
types: &mut module.types,
constants: &module.constants,
- expressions: &mut module.const_expressions,
- function_local_data: None,
+ overrides: &module.overrides,
+ expressions: &mut module.global_expressions,
+ expression_kind_tracker: global_expression_kind_tracker,
}
}
@@ -468,18 +595,22 @@ impl<'a> ConstantEvaluator<'a> {
pub fn for_wgsl_function(
module: &'a mut crate::Module,
expressions: &'a mut Arena<Expression>,
- expression_constness: &'a mut ExpressionConstnessTracker,
+ local_expression_kind_tracker: &'a mut ExpressionKindTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
) -> Self {
- Self::for_function(
- Behavior::Wgsl,
- module,
+ Self {
+ behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData {
+ global_expressions: &module.global_expressions,
+ emitter,
+ block,
+ })),
+ types: &mut module.types,
+ constants: &module.constants,
+ overrides: &module.overrides,
expressions,
- expression_constness,
- emitter,
- block,
- )
+ expression_kind_tracker: local_expression_kind_tracker,
+ }
}
/// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
@@ -489,39 +620,21 @@ impl<'a> ConstantEvaluator<'a> {
pub fn for_glsl_function(
module: &'a mut crate::Module,
expressions: &'a mut Arena<Expression>,
- expression_constness: &'a mut ExpressionConstnessTracker,
- emitter: &'a mut super::Emitter,
- block: &'a mut crate::Block,
- ) -> Self {
- Self::for_function(
- Behavior::Glsl,
- module,
- expressions,
- expression_constness,
- emitter,
- block,
- )
- }
-
- fn for_function(
- behavior: Behavior,
- module: &'a mut crate::Module,
- expressions: &'a mut Arena<Expression>,
- expression_constness: &'a mut ExpressionConstnessTracker,
+ local_expression_kind_tracker: &'a mut ExpressionKindTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
) -> Self {
Self {
- behavior,
+ behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
+ global_expressions: &module.global_expressions,
+ emitter,
+ block,
+ })),
types: &mut module.types,
constants: &module.constants,
+ overrides: &module.overrides,
expressions,
- function_local_data: Some(FunctionLocalData {
- const_expressions: &module.const_expressions,
- expression_constness,
- emitter,
- block,
- }),
+ expression_kind_tracker: local_expression_kind_tracker,
}
}
@@ -529,19 +642,18 @@ impl<'a> ConstantEvaluator<'a> {
crate::proc::GlobalCtx {
types: self.types,
constants: self.constants,
- const_expressions: match self.function_local_data {
- Some(ref data) => data.const_expressions,
+ overrides: self.overrides,
+ global_expressions: match self.function_local_data() {
+ Some(data) => data.global_expressions,
None => self.expressions,
},
}
}
fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
- if let Some(ref function_local_data) = self.function_local_data {
- if !function_local_data.expression_constness.is_const(expr) {
- log::debug!("check: SubexpressionsAreNotConstant");
- return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
- }
+ if !self.expression_kind_tracker.is_const(expr) {
+ log::debug!("check: SubexpressionsAreNotConstant");
+ return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
}
Ok(())
}
@@ -554,11 +666,11 @@ impl<'a> ConstantEvaluator<'a> {
Expression::Constant(c) => {
// Are we working in a function's expression arena, or the
// module's constant expression arena?
- if let Some(ref function_local_data) = self.function_local_data {
+ if let Some(function_local_data) = self.function_local_data() {
// Deep-copy the constant's value into our arena.
self.copy_from(
self.constants[c].init,
- function_local_data.const_expressions,
+ function_local_data.global_expressions,
)
} else {
// "See through" the constant and use its initializer.
@@ -580,9 +692,11 @@ impl<'a> ConstantEvaluator<'a> {
/// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
/// `self` contributes to.
///
- /// If `expr`'s value cannot be determined at compile time, return a an
- /// error. If it's acceptable to evaluate `expr` at runtime, this error can
- /// be ignored, and the caller can append `expr` to the arena itself.
+ /// If `expr`'s value cannot be determined at compile time, and `self` is
+ /// contributing to some function's expression arena, then append `expr` to
+ /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
+ /// contributing to the module's constant expression arena; since `expr`'s
+ /// value is not a constant, return an error.
///
/// We only consider `expr` itself, without recursing into its operands. Its
/// operands must all have been produced by prior calls to
@@ -595,16 +709,81 @@ impl<'a> ConstantEvaluator<'a> {
/// [`Swizzle`]: Expression::Swizzle
pub fn try_eval_and_append(
&mut self,
+ expr: Expression,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ match self.expression_kind_tracker.type_of_with_expr(&expr) {
+ ExpressionKind::Const => {
+ let eval_result = self.try_eval_and_append_impl(&expr, span);
+ // We should be able to evaluate `Const` expressions at this
+ // point. If we failed to, then that probably means we just
+ // haven't implemented that part of constant evaluation. Work
+ // around this by simply emitting it as a run-time expression.
+ if self.behavior.has_runtime_restrictions()
+ && matches!(
+ eval_result,
+ Err(ConstantEvaluatorError::NotImplemented(_)
+ | ConstantEvaluatorError::InvalidBinaryOpArgs,)
+ )
+ {
+ Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
+ } else {
+ eval_result
+ }
+ }
+ ExpressionKind::Override => match self.behavior {
+ Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
+ Ok(self.append_expr(expr, span, ExpressionKind::Override))
+ }
+ Behavior::Wgsl(WgslRestrictions::Const) => {
+ Err(ConstantEvaluatorError::OverrideExpr)
+ }
+ Behavior::Glsl(_) => {
+ unreachable!()
+ }
+ },
+ ExpressionKind::Runtime => {
+ if self.behavior.has_runtime_restrictions() {
+ Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
+ } else {
+ Err(ConstantEvaluatorError::RuntimeExpr)
+ }
+ }
+ }
+ }
+
+ /// Is the [`Self::expressions`] arena the global module expression arena?
+ const fn is_global_arena(&self) -> bool {
+ matches!(
+ self.behavior,
+ Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override)
+ | Behavior::Glsl(GlslRestrictions::Const)
+ )
+ }
+
+ const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
+ match self.behavior {
+ Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data))
+ | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
+ Some(function_local_data)
+ }
+ _ => None,
+ }
+ }
+
+ fn try_eval_and_append_impl(
+ &mut self,
expr: &Expression,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
log::trace!("try_eval_and_append: {:?}", expr);
match *expr {
- Expression::Constant(c) if self.function_local_data.is_none() => {
+ Expression::Constant(c) if self.is_global_arena() => {
// "See through" the constant and use its initializer.
// This is mainly done to avoid having constants pointing to other constants.
Ok(self.constants[c].init)
}
+ Expression::Override(_) => Err(ConstantEvaluatorError::Override),
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
self.register_evaluated_expr(expr.clone(), span)
}
@@ -685,8 +864,8 @@ impl<'a> ConstantEvaluator<'a> {
format!("{fun:?} built-in function"),
)),
Expression::ArrayLength(expr) => match self.behavior {
- Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength),
- Behavior::Glsl => {
+ Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
+ Behavior::Glsl(_) => {
let expr = self.check_and_get(expr)?;
self.array_length(expr, span)
}
@@ -707,6 +886,12 @@ impl<'a> ConstantEvaluator<'a> {
Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
Err(ConstantEvaluatorError::RayQueryExpression)
}
+ Expression::SubgroupBallotResult { .. } => {
+ Err(ConstantEvaluatorError::SubgroupExpression)
+ }
+ Expression::SubgroupOperationResult { .. } => {
+ Err(ConstantEvaluatorError::SubgroupExpression)
+ }
}
}
@@ -765,10 +950,10 @@ impl<'a> ConstantEvaluator<'a> {
pattern: [crate::SwizzleComponent; 4],
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let mut get_dst_ty = |ty| match self.types[ty].inner {
- crate::TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
+ TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
Type {
name: None,
- inner: crate::TypeInner::Vector { size, scalar },
+ inner: TypeInner::Vector { size, scalar },
},
span,
)),
@@ -1059,13 +1244,11 @@ impl<'a> ConstantEvaluator<'a> {
Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
match self.types[ty].inner {
TypeInner::Array { size, .. } => match size {
- crate::ArraySize::Constant(len) => {
+ ArraySize::Constant(len) => {
let expr = Expression::Literal(Literal::U32(len.get()));
self.register_evaluated_expr(expr, span)
}
- crate::ArraySize::Dynamic => {
- Err(ConstantEvaluatorError::ArrayLengthDynamic)
- }
+ ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
},
_ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
}
@@ -1128,7 +1311,7 @@ impl<'a> ConstantEvaluator<'a> {
Expression::ZeroValue(ty)
if matches!(
self.types[ty].inner,
- crate::TypeInner::Scalar(crate::Scalar {
+ TypeInner::Scalar(crate::Scalar {
kind: ScalarKind::Uint,
..
})
@@ -1443,7 +1626,7 @@ impl<'a> ConstantEvaluator<'a> {
return self.cast(expr, target, span);
};
- let crate::TypeInner::Array {
+ let TypeInner::Array {
base: _,
size,
stride: _,
@@ -1853,29 +2036,35 @@ impl<'a> ConstantEvaluator<'a> {
crate::valid::check_literal_value(literal)?;
}
- if let Some(FunctionLocalData {
- ref mut emitter,
- ref mut block,
- ref mut expression_constness,
- ..
- }) = self.function_local_data
- {
- let is_running = emitter.is_running();
- let needs_pre_emit = expr.needs_pre_emit();
- if is_running && needs_pre_emit {
- block.extend(emitter.finish(self.expressions));
- let h = self.expressions.append(expr, span);
- emitter.start(self.expressions);
- expression_constness.insert(h);
- Ok(h)
- } else {
- let h = self.expressions.append(expr, span);
- expression_constness.insert(h);
- Ok(h)
+ Ok(self.append_expr(expr, span, ExpressionKind::Const))
+ }
+
+ fn append_expr(
+ &mut self,
+ expr: Expression,
+ span: Span,
+ expr_type: ExpressionKind,
+ ) -> Handle<Expression> {
+ let h = match self.behavior {
+ Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data))
+ | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
+ let is_running = function_local_data.emitter.is_running();
+ let needs_pre_emit = expr.needs_pre_emit();
+ if is_running && needs_pre_emit {
+ function_local_data
+ .block
+ .extend(function_local_data.emitter.finish(self.expressions));
+ let h = self.expressions.append(expr, span);
+ function_local_data.emitter.start(self.expressions);
+ h
+ } else {
+ self.expressions.append(expr, span)
+ }
}
- } else {
- Ok(self.expressions.append(expr, span))
- }
+ _ => self.expressions.append(expr, span),
+ };
+ self.expression_kind_tracker.insert(h, expr_type);
+ h
}
fn resolve_type(
@@ -2029,13 +2218,14 @@ mod tests {
UniqueArena, VectorSize,
};
- use super::{Behavior, ConstantEvaluator};
+ use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
#[test]
fn unary_op() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let scalar_ty = types.insert(
Type {
@@ -2059,9 +2249,8 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: scalar_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
@@ -2070,9 +2259,8 @@ mod tests {
let h1 = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: scalar_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(8)), Default::default()),
},
Default::default(),
@@ -2081,9 +2269,8 @@ mod tests {
let vec_h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: vec_ty,
- init: const_expressions.append(
+ init: global_expressions.append(
Expression::Compose {
ty: vec_ty,
components: vec![constants[h].init, constants[h1].init],
@@ -2094,8 +2281,8 @@ mod tests {
Default::default(),
);
- let expr = const_expressions.append(Expression::Constant(h), Default::default());
- let expr1 = const_expressions.append(Expression::Constant(vec_h), Default::default());
+ let expr = global_expressions.append(Expression::Constant(h), Default::default());
+ let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
let expr2 = Expression::Unary {
op: UnaryOperator::Negate,
@@ -2112,35 +2299,37 @@ mod tests {
expr: expr1,
};
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let res1 = solver
- .try_eval_and_append(&expr2, Default::default())
+ .try_eval_and_append(expr2, Default::default())
.unwrap();
let res2 = solver
- .try_eval_and_append(&expr3, Default::default())
+ .try_eval_and_append(expr3, Default::default())
.unwrap();
let res3 = solver
- .try_eval_and_append(&expr4, Default::default())
+ .try_eval_and_append(expr4, Default::default())
.unwrap();
assert_eq!(
- const_expressions[res1],
+ global_expressions[res1],
Expression::Literal(Literal::I32(-4))
);
assert_eq!(
- const_expressions[res2],
+ global_expressions[res2],
Expression::Literal(Literal::I32(!4))
);
- let res3_inner = &const_expressions[res3];
+ let res3_inner = &global_expressions[res3];
match *res3_inner {
Expression::Compose {
@@ -2150,11 +2339,11 @@ mod tests {
assert_eq!(*ty, vec_ty);
let mut components_iter = components.iter().copied();
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::I32(!4))
);
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::I32(!8))
);
assert!(components_iter.next().is_none());
@@ -2167,7 +2356,8 @@ mod tests {
fn cast() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let scalar_ty = types.insert(
Type {
@@ -2180,15 +2370,14 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: scalar_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
);
- let expr = const_expressions.append(Expression::Constant(h), Default::default());
+ let expr = global_expressions.append(Expression::Constant(h), Default::default());
let root = Expression::As {
expr,
@@ -2196,20 +2385,22 @@ mod tests {
convert: Some(crate::BOOL_WIDTH),
};
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let res = solver
- .try_eval_and_append(&root, Default::default())
+ .try_eval_and_append(root, Default::default())
.unwrap();
assert_eq!(
- const_expressions[res],
+ global_expressions[res],
Expression::Literal(Literal::Bool(true))
);
}
@@ -2218,7 +2409,8 @@ mod tests {
fn access() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let matrix_ty = types.insert(
Type {
@@ -2247,7 +2439,7 @@ mod tests {
let mut vec2_components = Vec::with_capacity(3);
for i in 0..3 {
- let h = const_expressions.append(
+ let h = global_expressions.append(
Expression::Literal(Literal::F32(i as f32)),
Default::default(),
);
@@ -2256,7 +2448,7 @@ mod tests {
}
for i in 3..6 {
- let h = const_expressions.append(
+ let h = global_expressions.append(
Expression::Literal(Literal::F32(i as f32)),
Default::default(),
);
@@ -2267,9 +2459,8 @@ mod tests {
let vec1 = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: vec_ty,
- init: const_expressions.append(
+ init: global_expressions.append(
Expression::Compose {
ty: vec_ty,
components: vec1_components,
@@ -2283,9 +2474,8 @@ mod tests {
let vec2 = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: vec_ty,
- init: const_expressions.append(
+ init: global_expressions.append(
Expression::Compose {
ty: vec_ty,
components: vec2_components,
@@ -2299,9 +2489,8 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: matrix_ty,
- init: const_expressions.append(
+ init: global_expressions.append(
Expression::Compose {
ty: matrix_ty,
components: vec![constants[vec1].init, constants[vec2].init],
@@ -2312,20 +2501,22 @@ mod tests {
Default::default(),
);
- let base = const_expressions.append(Expression::Constant(h), Default::default());
+ let base = global_expressions.append(Expression::Constant(h), Default::default());
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let root1 = Expression::AccessIndex { base, index: 1 };
let res1 = solver
- .try_eval_and_append(&root1, Default::default())
+ .try_eval_and_append(root1, Default::default())
.unwrap();
let root2 = Expression::AccessIndex {
@@ -2334,10 +2525,10 @@ mod tests {
};
let res2 = solver
- .try_eval_and_append(&root2, Default::default())
+ .try_eval_and_append(root2, Default::default())
.unwrap();
- match const_expressions[res1] {
+ match global_expressions[res1] {
Expression::Compose {
ref ty,
ref components,
@@ -2345,15 +2536,15 @@ mod tests {
assert_eq!(*ty, vec_ty);
let mut components_iter = components.iter().copied();
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::F32(3.))
);
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::F32(4.))
);
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::F32(5.))
);
assert!(components_iter.next().is_none());
@@ -2362,7 +2553,7 @@ mod tests {
}
assert_eq!(
- const_expressions[res2],
+ global_expressions[res2],
Expression::Literal(Literal::F32(5.))
);
}
@@ -2371,7 +2562,8 @@ mod tests {
fn compose_of_constants() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let i32_ty = types.insert(
Type {
@@ -2395,27 +2587,28 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: i32_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
);
- let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
+ let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let solved_compose = solver
.try_eval_and_append(
- &Expression::Compose {
+ Expression::Compose {
ty: vec2_i32_ty,
components: vec![h_expr, h_expr],
},
@@ -2424,7 +2617,7 @@ mod tests {
.unwrap();
let solved_negate = solver
.try_eval_and_append(
- &Expression::Unary {
+ Expression::Unary {
op: UnaryOperator::Negate,
expr: solved_compose,
},
@@ -2432,11 +2625,11 @@ mod tests {
)
.unwrap();
- let pass = match const_expressions[solved_negate] {
+ let pass = match global_expressions[solved_negate] {
Expression::Compose { ty, ref components } => {
ty == vec2_i32_ty
&& components.iter().all(|&component| {
- let component = &const_expressions[component];
+ let component = &global_expressions[component];
matches!(*component, Expression::Literal(Literal::I32(-4)))
})
}
@@ -2451,7 +2644,8 @@ mod tests {
fn splat_of_constant() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let i32_ty = types.insert(
Type {
@@ -2475,27 +2669,28 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: i32_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
);
- let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
+ let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let solved_compose = solver
.try_eval_and_append(
- &Expression::Splat {
+ Expression::Splat {
size: VectorSize::Bi,
value: h_expr,
},
@@ -2504,7 +2699,7 @@ mod tests {
.unwrap();
let solved_negate = solver
.try_eval_and_append(
- &Expression::Unary {
+ Expression::Unary {
op: UnaryOperator::Negate,
expr: solved_compose,
},
@@ -2512,11 +2707,11 @@ mod tests {
)
.unwrap();
- let pass = match const_expressions[solved_negate] {
+ let pass = match global_expressions[solved_negate] {
Expression::Compose { ty, ref components } => {
ty == vec2_i32_ty
&& components.iter().all(|&component| {
- let component = &const_expressions[component];
+ let component = &global_expressions[component];
matches!(*component, Expression::Literal(Literal::I32(-4)))
})
}