diff options
Diffstat (limited to 'third_party/rust/naga/src/proc')
-rw-r--r-- | third_party/rust/naga/src/proc/constant_evaluator.rs | 589 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/index.rs | 4 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/mod.rs | 105 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/terminator.rs | 3 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/typifier.rs | 8 |
5 files changed, 445 insertions, 264 deletions
diff --git a/third_party/rust/naga/src/proc/constant_evaluator.rs b/third_party/rust/naga/src/proc/constant_evaluator.rs index 983af3718c..ead3d00980 100644 --- a/third_party/rust/naga/src/proc/constant_evaluator.rs +++ b/third_party/rust/naga/src/proc/constant_evaluator.rs @@ -4,8 +4,8 @@ use arrayvec::ArrayVec; use crate::{ arena::{Arena, Handle, UniqueArena}, - ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner, - UnaryOperator, + ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type, + TypeInner, UnaryOperator, }; /// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating @@ -253,9 +253,20 @@ gen_component_wise_extractor! { } #[derive(Debug)] -enum Behavior { - Wgsl, - Glsl, +enum Behavior<'a> { + Wgsl(WgslRestrictions<'a>), + Glsl(GlslRestrictions<'a>), +} + +impl Behavior<'_> { + /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions. + const fn has_runtime_restrictions(&self) -> bool { + matches!( + self, + &Behavior::Wgsl(WgslRestrictions::Runtime(_)) + | &Behavior::Glsl(GlslRestrictions::Runtime(_)) + ) + } } /// A context for evaluating constant expressions. @@ -278,7 +289,7 @@ enum Behavior { #[derive(Debug)] pub struct ConstantEvaluator<'a> { /// Which language's evaluation rules we should follow. - behavior: Behavior, + behavior: Behavior<'a>, /// The module's type arena. /// @@ -291,71 +302,155 @@ pub struct ConstantEvaluator<'a> { /// The module's constant arena. constants: &'a Arena<Constant>, + /// The module's override arena. + overrides: &'a Arena<Override>, + /// The arena to which we are contributing expressions. expressions: &'a mut Arena<Expression>, - /// When `self.expressions` refers to a function's local expression - /// arena, this needs to be populated - function_local_data: Option<FunctionLocalData<'a>>, + /// Tracks the constness of expressions residing in [`Self::expressions`] + expression_kind_tracker: &'a mut ExpressionKindTracker, +} + +#[derive(Debug)] +enum WgslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + Override, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), +} + +#[derive(Debug)] +enum GlslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), } #[derive(Debug)] struct FunctionLocalData<'a> { /// Global constant expressions - const_expressions: &'a Arena<Expression>, - /// Tracks the constness of expressions residing in `ConstantEvaluator.expressions` - expression_constness: &'a mut ExpressionConstnessTracker, + global_expressions: &'a Arena<Expression>, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum ExpressionKind { + Const, + Override, + Runtime, +} + #[derive(Debug)] -pub struct ExpressionConstnessTracker { - inner: bit_set::BitSet, +pub struct ExpressionKindTracker { + inner: Vec<ExpressionKind>, } -impl ExpressionConstnessTracker { - pub fn new() -> Self { - Self { - inner: bit_set::BitSet::new(), - } +impl ExpressionKindTracker { + pub const fn new() -> Self { + Self { inner: Vec::new() } } /// Forces the the expression to not be const pub fn force_non_const(&mut self, value: Handle<Expression>) { - self.inner.remove(value.index()); + self.inner[value.index()] = ExpressionKind::Runtime; } - fn insert(&mut self, value: Handle<Expression>) { - self.inner.insert(value.index()); + pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) { + assert_eq!(self.inner.len(), value.index()); + self.inner.push(expr_type); + } + pub fn is_const(&self, h: Handle<Expression>) -> bool { + matches!(self.type_of(h), ExpressionKind::Const) + } + + pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool { + matches!( + self.type_of(h), + ExpressionKind::Const | ExpressionKind::Override + ) } - pub fn is_const(&self, value: Handle<Expression>) -> bool { - self.inner.contains(value.index()) + fn type_of(&self, value: Handle<Expression>) -> ExpressionKind { + self.inner[value.index()] } pub fn from_arena(arena: &Arena<Expression>) -> Self { - let mut tracker = Self::new(); - for (handle, expr) in arena.iter() { - let insert = match *expr { - crate::Expression::Literal(_) - | crate::Expression::ZeroValue(_) - | crate::Expression::Constant(_) => true, - crate::Expression::Compose { ref components, .. } => { - components.iter().all(|h| tracker.is_const(*h)) - } - crate::Expression::Splat { value, .. } => tracker.is_const(value), - _ => false, - }; - if insert { - tracker.insert(handle); - } + let mut tracker = Self { + inner: Vec::with_capacity(arena.len()), + }; + for (_, expr) in arena.iter() { + tracker.inner.push(tracker.type_of_with_expr(expr)); } tracker } + + fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind { + match *expr { + Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { + ExpressionKind::Const + } + Expression::Override(_) => ExpressionKind::Override, + Expression::Compose { ref components, .. } => { + let mut expr_type = ExpressionKind::Const; + for component in components { + expr_type = expr_type.max(self.type_of(*component)) + } + expr_type + } + Expression::Splat { value, .. } => self.type_of(value), + Expression::AccessIndex { base, .. } => self.type_of(base), + Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)), + Expression::Swizzle { vector, .. } => self.type_of(vector), + Expression::Unary { expr, .. } => self.type_of(expr), + Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)), + Expression::Math { + arg, + arg1, + arg2, + arg3, + .. + } => self + .type_of(arg) + .max( + arg1.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg2.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg3.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ), + Expression::As { expr, .. } => self.type_of(expr), + Expression::Select { + condition, + accept, + reject, + } => self + .type_of(condition) + .max(self.type_of(accept)) + .max(self.type_of(reject)), + Expression::Relational { argument, .. } => self.type_of(argument), + Expression::ArrayLength(expr) => self.type_of(expr), + _ => ExpressionKind::Runtime, + } + } } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantEvaluatorError { #[error("Constants cannot access function arguments")] FunctionArg, @@ -381,6 +476,8 @@ pub enum ConstantEvaluatorError { ImageExpression, #[error("Constants don't support ray query expressions")] RayQueryExpression, + #[error("Constants don't support subgroup expressions")] + SubgroupExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -432,6 +529,12 @@ pub enum ConstantEvaluatorError { ShiftedMoreThan32Bits, #[error(transparent)] Literal(#[from] crate::valid::LiteralError), + #[error("Can't use pipeline-overridable constants in const-expressions")] + Override, + #[error("Unexpected runtime-expression")] + RuntimeExpr, + #[error("Unexpected override-expression")] + OverrideExpr, } impl<'a> ConstantEvaluator<'a> { @@ -439,25 +542,49 @@ impl<'a> ConstantEvaluator<'a> { /// constant expression arena. /// /// Report errors according to WGSL's rules for constant evaluation. - pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Wgsl, module) + pub fn for_wgsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, + in_override_ctx: bool, + ) -> Self { + Self::for_module( + Behavior::Wgsl(if in_override_ctx { + WgslRestrictions::Override + } else { + WgslRestrictions::Const + }), + module, + global_expression_kind_tracker, + ) } /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s /// constant expression arena. /// /// Report errors according to GLSL's rules for constant evaluation. - pub fn for_glsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Glsl, module) + pub fn for_glsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, + ) -> Self { + Self::for_module( + Behavior::Glsl(GlslRestrictions::Const), + module, + global_expression_kind_tracker, + ) } - fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self { + fn for_module( + behavior: Behavior<'a>, + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, + ) -> Self { Self { behavior, types: &mut module.types, constants: &module.constants, - expressions: &mut module.const_expressions, - function_local_data: None, + overrides: &module.overrides, + expressions: &mut module.global_expressions, + expression_kind_tracker: global_expression_kind_tracker, } } @@ -468,18 +595,22 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_wgsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena<Expression>, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionKindTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { - Self::for_function( - Behavior::Wgsl, - module, + Self { + behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData { + global_expressions: &module.global_expressions, + emitter, + block, + })), + types: &mut module.types, + constants: &module.constants, + overrides: &module.overrides, expressions, - expression_constness, - emitter, - block, - ) + expression_kind_tracker: local_expression_kind_tracker, + } } /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s @@ -489,39 +620,21 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_glsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena<Expression>, - expression_constness: &'a mut ExpressionConstnessTracker, - emitter: &'a mut super::Emitter, - block: &'a mut crate::Block, - ) -> Self { - Self::for_function( - Behavior::Glsl, - module, - expressions, - expression_constness, - emitter, - block, - ) - } - - fn for_function( - behavior: Behavior, - module: &'a mut crate::Module, - expressions: &'a mut Arena<Expression>, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionKindTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { Self { - behavior, + behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData { + global_expressions: &module.global_expressions, + emitter, + block, + })), types: &mut module.types, constants: &module.constants, + overrides: &module.overrides, expressions, - function_local_data: Some(FunctionLocalData { - const_expressions: &module.const_expressions, - expression_constness, - emitter, - block, - }), + expression_kind_tracker: local_expression_kind_tracker, } } @@ -529,19 +642,18 @@ impl<'a> ConstantEvaluator<'a> { crate::proc::GlobalCtx { types: self.types, constants: self.constants, - const_expressions: match self.function_local_data { - Some(ref data) => data.const_expressions, + overrides: self.overrides, + global_expressions: match self.function_local_data() { + Some(data) => data.global_expressions, None => self.expressions, }, } } fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> { - if let Some(ref function_local_data) = self.function_local_data { - if !function_local_data.expression_constness.is_const(expr) { - log::debug!("check: SubexpressionsAreNotConstant"); - return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); - } + if !self.expression_kind_tracker.is_const(expr) { + log::debug!("check: SubexpressionsAreNotConstant"); + return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); } Ok(()) } @@ -554,11 +666,11 @@ impl<'a> ConstantEvaluator<'a> { Expression::Constant(c) => { // Are we working in a function's expression arena, or the // module's constant expression arena? - if let Some(ref function_local_data) = self.function_local_data { + if let Some(function_local_data) = self.function_local_data() { // Deep-copy the constant's value into our arena. self.copy_from( self.constants[c].init, - function_local_data.const_expressions, + function_local_data.global_expressions, ) } else { // "See through" the constant and use its initializer. @@ -580,9 +692,11 @@ impl<'a> ConstantEvaluator<'a> { /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena /// `self` contributes to. /// - /// If `expr`'s value cannot be determined at compile time, return a an - /// error. If it's acceptable to evaluate `expr` at runtime, this error can - /// be ignored, and the caller can append `expr` to the arena itself. + /// If `expr`'s value cannot be determined at compile time, and `self` is + /// contributing to some function's expression arena, then append `expr` to + /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be + /// contributing to the module's constant expression arena; since `expr`'s + /// value is not a constant, return an error. /// /// We only consider `expr` itself, without recursing into its operands. Its /// operands must all have been produced by prior calls to @@ -595,16 +709,81 @@ impl<'a> ConstantEvaluator<'a> { /// [`Swizzle`]: Expression::Swizzle pub fn try_eval_and_append( &mut self, + expr: Expression, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + match self.expression_kind_tracker.type_of_with_expr(&expr) { + ExpressionKind::Const => { + let eval_result = self.try_eval_and_append_impl(&expr, span); + // We should be able to evaluate `Const` expressions at this + // point. If we failed to, then that probably means we just + // haven't implemented that part of constant evaluation. Work + // around this by simply emitting it as a run-time expression. + if self.behavior.has_runtime_restrictions() + && matches!( + eval_result, + Err(ConstantEvaluatorError::NotImplemented(_) + | ConstantEvaluatorError::InvalidBinaryOpArgs,) + ) + { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + eval_result + } + } + ExpressionKind::Override => match self.behavior { + Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => { + Ok(self.append_expr(expr, span, ExpressionKind::Override)) + } + Behavior::Wgsl(WgslRestrictions::Const) => { + Err(ConstantEvaluatorError::OverrideExpr) + } + Behavior::Glsl(_) => { + unreachable!() + } + }, + ExpressionKind::Runtime => { + if self.behavior.has_runtime_restrictions() { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + Err(ConstantEvaluatorError::RuntimeExpr) + } + } + } + } + + /// Is the [`Self::expressions`] arena the global module expression arena? + const fn is_global_arena(&self) -> bool { + matches!( + self.behavior, + Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override) + | Behavior::Glsl(GlslRestrictions::Const) + ) + } + + const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> { + match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => { + Some(function_local_data) + } + _ => None, + } + } + + fn try_eval_and_append_impl( + &mut self, expr: &Expression, span: Span, ) -> Result<Handle<Expression>, ConstantEvaluatorError> { log::trace!("try_eval_and_append: {:?}", expr); match *expr { - Expression::Constant(c) if self.function_local_data.is_none() => { + Expression::Constant(c) if self.is_global_arena() => { // "See through" the constant and use its initializer. // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } + Expression::Override(_) => Err(ConstantEvaluatorError::Override), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } @@ -685,8 +864,8 @@ impl<'a> ConstantEvaluator<'a> { format!("{fun:?} built-in function"), )), Expression::ArrayLength(expr) => match self.behavior { - Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength), - Behavior::Glsl => { + Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength), + Behavior::Glsl(_) => { let expr = self.check_and_get(expr)?; self.array_length(expr, span) } @@ -707,6 +886,12 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } + Expression::SubgroupBallotResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } + Expression::SubgroupOperationResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } } } @@ -765,10 +950,10 @@ impl<'a> ConstantEvaluator<'a> { pattern: [crate::SwizzleComponent; 4], ) -> Result<Handle<Expression>, ConstantEvaluatorError> { let mut get_dst_ty = |ty| match self.types[ty].inner { - crate::TypeInner::Vector { size: _, scalar } => Ok(self.types.insert( + TypeInner::Vector { size: _, scalar } => Ok(self.types.insert( Type { name: None, - inner: crate::TypeInner::Vector { size, scalar }, + inner: TypeInner::Vector { size, scalar }, }, span, )), @@ -1059,13 +1244,11 @@ impl<'a> ConstantEvaluator<'a> { Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => { match self.types[ty].inner { TypeInner::Array { size, .. } => match size { - crate::ArraySize::Constant(len) => { + ArraySize::Constant(len) => { let expr = Expression::Literal(Literal::U32(len.get())); self.register_evaluated_expr(expr, span) } - crate::ArraySize::Dynamic => { - Err(ConstantEvaluatorError::ArrayLengthDynamic) - } + ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic), }, _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg), } @@ -1128,7 +1311,7 @@ impl<'a> ConstantEvaluator<'a> { Expression::ZeroValue(ty) if matches!( self.types[ty].inner, - crate::TypeInner::Scalar(crate::Scalar { + TypeInner::Scalar(crate::Scalar { kind: ScalarKind::Uint, .. }) @@ -1443,7 +1626,7 @@ impl<'a> ConstantEvaluator<'a> { return self.cast(expr, target, span); }; - let crate::TypeInner::Array { + let TypeInner::Array { base: _, size, stride: _, @@ -1853,29 +2036,35 @@ impl<'a> ConstantEvaluator<'a> { crate::valid::check_literal_value(literal)?; } - if let Some(FunctionLocalData { - ref mut emitter, - ref mut block, - ref mut expression_constness, - .. - }) = self.function_local_data - { - let is_running = emitter.is_running(); - let needs_pre_emit = expr.needs_pre_emit(); - if is_running && needs_pre_emit { - block.extend(emitter.finish(self.expressions)); - let h = self.expressions.append(expr, span); - emitter.start(self.expressions); - expression_constness.insert(h); - Ok(h) - } else { - let h = self.expressions.append(expr, span); - expression_constness.insert(h); - Ok(h) + Ok(self.append_expr(expr, span, ExpressionKind::Const)) + } + + fn append_expr( + &mut self, + expr: Expression, + span: Span, + expr_type: ExpressionKind, + ) -> Handle<Expression> { + let h = match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => { + let is_running = function_local_data.emitter.is_running(); + let needs_pre_emit = expr.needs_pre_emit(); + if is_running && needs_pre_emit { + function_local_data + .block + .extend(function_local_data.emitter.finish(self.expressions)); + let h = self.expressions.append(expr, span); + function_local_data.emitter.start(self.expressions); + h + } else { + self.expressions.append(expr, span) + } } - } else { - Ok(self.expressions.append(expr, span)) - } + _ => self.expressions.append(expr, span), + }; + self.expression_kind_tracker.insert(h, expr_type); + h } fn resolve_type( @@ -2029,13 +2218,14 @@ mod tests { UniqueArena, VectorSize, }; - use super::{Behavior, ConstantEvaluator}; + use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions}; #[test] fn unary_op() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let scalar_ty = types.insert( Type { @@ -2059,9 +2249,8 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), @@ -2070,9 +2259,8 @@ mod tests { let h1 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(8)), Default::default()), }, Default::default(), @@ -2081,9 +2269,8 @@ mod tests { let vec_h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec![constants[h].init, constants[h1].init], @@ -2094,8 +2281,8 @@ mod tests { Default::default(), ); - let expr = const_expressions.append(Expression::Constant(h), Default::default()); - let expr1 = const_expressions.append(Expression::Constant(vec_h), Default::default()); + let expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default()); let expr2 = Expression::Unary { op: UnaryOperator::Negate, @@ -2112,35 +2299,37 @@ mod tests { expr: expr1, }; + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let res1 = solver - .try_eval_and_append(&expr2, Default::default()) + .try_eval_and_append(expr2, Default::default()) .unwrap(); let res2 = solver - .try_eval_and_append(&expr3, Default::default()) + .try_eval_and_append(expr3, Default::default()) .unwrap(); let res3 = solver - .try_eval_and_append(&expr4, Default::default()) + .try_eval_and_append(expr4, Default::default()) .unwrap(); assert_eq!( - const_expressions[res1], + global_expressions[res1], Expression::Literal(Literal::I32(-4)) ); assert_eq!( - const_expressions[res2], + global_expressions[res2], Expression::Literal(Literal::I32(!4)) ); - let res3_inner = &const_expressions[res3]; + let res3_inner = &global_expressions[res3]; match *res3_inner { Expression::Compose { @@ -2150,11 +2339,11 @@ mod tests { assert_eq!(*ty, vec_ty); let mut components_iter = components.iter().copied(); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::I32(!4)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::I32(!8)) ); assert!(components_iter.next().is_none()); @@ -2167,7 +2356,8 @@ mod tests { fn cast() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let scalar_ty = types.insert( Type { @@ -2180,15 +2370,14 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let expr = const_expressions.append(Expression::Constant(h), Default::default()); + let expr = global_expressions.append(Expression::Constant(h), Default::default()); let root = Expression::As { expr, @@ -2196,20 +2385,22 @@ mod tests { convert: Some(crate::BOOL_WIDTH), }; + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let res = solver - .try_eval_and_append(&root, Default::default()) + .try_eval_and_append(root, Default::default()) .unwrap(); assert_eq!( - const_expressions[res], + global_expressions[res], Expression::Literal(Literal::Bool(true)) ); } @@ -2218,7 +2409,8 @@ mod tests { fn access() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let matrix_ty = types.insert( Type { @@ -2247,7 +2439,7 @@ mod tests { let mut vec2_components = Vec::with_capacity(3); for i in 0..3 { - let h = const_expressions.append( + let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); @@ -2256,7 +2448,7 @@ mod tests { } for i in 3..6 { - let h = const_expressions.append( + let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); @@ -2267,9 +2459,8 @@ mod tests { let vec1 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec1_components, @@ -2283,9 +2474,8 @@ mod tests { let vec2 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec2_components, @@ -2299,9 +2489,8 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: matrix_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: matrix_ty, components: vec![constants[vec1].init, constants[vec2].init], @@ -2312,20 +2501,22 @@ mod tests { Default::default(), ); - let base = const_expressions.append(Expression::Constant(h), Default::default()); + let base = global_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let root1 = Expression::AccessIndex { base, index: 1 }; let res1 = solver - .try_eval_and_append(&root1, Default::default()) + .try_eval_and_append(root1, Default::default()) .unwrap(); let root2 = Expression::AccessIndex { @@ -2334,10 +2525,10 @@ mod tests { }; let res2 = solver - .try_eval_and_append(&root2, Default::default()) + .try_eval_and_append(root2, Default::default()) .unwrap(); - match const_expressions[res1] { + match global_expressions[res1] { Expression::Compose { ref ty, ref components, @@ -2345,15 +2536,15 @@ mod tests { assert_eq!(*ty, vec_ty); let mut components_iter = components.iter().copied(); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(3.)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(4.)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(5.)) ); assert!(components_iter.next().is_none()); @@ -2362,7 +2553,7 @@ mod tests { } assert_eq!( - const_expressions[res2], + global_expressions[res2], Expression::Literal(Literal::F32(5.)) ); } @@ -2371,7 +2562,8 @@ mod tests { fn compose_of_constants() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { @@ -2395,27 +2587,28 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let solved_compose = solver .try_eval_and_append( - &Expression::Compose { + Expression::Compose { ty: vec2_i32_ty, components: vec![h_expr, h_expr], }, @@ -2424,7 +2617,7 @@ mod tests { .unwrap(); let solved_negate = solver .try_eval_and_append( - &Expression::Unary { + Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, @@ -2432,11 +2625,11 @@ mod tests { ) .unwrap(); - let pass = match const_expressions[solved_negate] { + let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { - let component = &const_expressions[component]; + let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::I32(-4))) }) } @@ -2451,7 +2644,8 @@ mod tests { fn splat_of_constant() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { @@ -2475,27 +2669,28 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let solved_compose = solver .try_eval_and_append( - &Expression::Splat { + Expression::Splat { size: VectorSize::Bi, value: h_expr, }, @@ -2504,7 +2699,7 @@ mod tests { .unwrap(); let solved_negate = solver .try_eval_and_append( - &Expression::Unary { + Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, @@ -2512,11 +2707,11 @@ mod tests { ) .unwrap(); - let pass = match const_expressions[solved_negate] { + let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { - let component = &const_expressions[component]; + let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::I32(-4))) }) } diff --git a/third_party/rust/naga/src/proc/index.rs b/third_party/rust/naga/src/proc/index.rs index af3221c0fe..e2c3de8eb0 100644 --- a/third_party/rust/naga/src/proc/index.rs +++ b/third_party/rust/naga/src/proc/index.rs @@ -239,7 +239,7 @@ pub enum GuardedIndex { pub fn find_checked_indexes( module: &crate::Module, function: &crate::Function, - info: &crate::valid::FunctionInfo, + info: &valid::FunctionInfo, policies: BoundsCheckPolicies, ) -> BitSet { use crate::Expression as Ex; @@ -321,7 +321,7 @@ pub fn access_needs_check( mut index: GuardedIndex, module: &crate::Module, function: &crate::Function, - info: &crate::valid::FunctionInfo, + info: &valid::FunctionInfo, ) -> Option<IndexableLength> { let base_inner = info[base].ty.inner_with(&module.types); // Unwrap safety: `Err` here indicates unindexable base types and invalid diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs index 46cbb6c3b3..93aac5b3e5 100644 --- a/third_party/rust/naga/src/proc/mod.rs +++ b/third_party/rust/naga/src/proc/mod.rs @@ -11,7 +11,7 @@ mod terminator; mod typifier; pub use constant_evaluator::{ - ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, + ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker, }; pub use emitter::Emitter; pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; @@ -153,56 +153,31 @@ impl super::Scalar { } } -impl PartialEq for crate::Literal { - fn eq(&self, other: &Self) -> bool { - match (*self, *other) { - (Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(), - (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(), - (Self::U32(a), Self::U32(b)) => a == b, - (Self::I32(a), Self::I32(b)) => a == b, - (Self::U64(a), Self::U64(b)) => a == b, - (Self::I64(a), Self::I64(b)) => a == b, - (Self::Bool(a), Self::Bool(b)) => a == b, - _ => false, - } - } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum HashableLiteral { + F64(u64), + F32(u32), + U32(u32), + I32(i32), + U64(u64), + I64(i64), + Bool(bool), + AbstractInt(i64), + AbstractFloat(u64), } -impl Eq for crate::Literal {} -impl std::hash::Hash for crate::Literal { - fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) { - match *self { - Self::F64(v) | Self::AbstractFloat(v) => { - hasher.write_u8(0); - v.to_bits().hash(hasher); - } - Self::F32(v) => { - hasher.write_u8(1); - v.to_bits().hash(hasher); - } - Self::U32(v) => { - hasher.write_u8(2); - v.hash(hasher); - } - Self::I32(v) => { - hasher.write_u8(3); - v.hash(hasher); - } - Self::Bool(v) => { - hasher.write_u8(4); - v.hash(hasher); - } - Self::I64(v) => { - hasher.write_u8(5); - v.hash(hasher); - } - Self::U64(v) => { - hasher.write_u8(6); - v.hash(hasher); - } - Self::AbstractInt(v) => { - hasher.write_u8(7); - v.hash(hasher); - } + +impl From<crate::Literal> for HashableLiteral { + fn from(l: crate::Literal) -> Self { + match l { + crate::Literal::F64(v) => Self::F64(v.to_bits()), + crate::Literal::F32(v) => Self::F32(v.to_bits()), + crate::Literal::U32(v) => Self::U32(v), + crate::Literal::I32(v) => Self::I32(v), + crate::Literal::U64(v) => Self::U64(v), + crate::Literal::I64(v) => Self::I64(v), + crate::Literal::Bool(v) => Self::Bool(v), + crate::Literal::AbstractInt(v) => Self::AbstractInt(v), + crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()), } } } @@ -216,8 +191,8 @@ impl crate::Literal { (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)), (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)), (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)), - (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)), - (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)), + (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)), + (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)), _ => None, } } @@ -279,8 +254,9 @@ impl super::TypeInner { self.scalar().map(|scalar| scalar.kind) } + /// Returns the scalar width in bytes pub fn scalar_width(&self) -> Option<u8> { - self.scalar().map(|scalar| scalar.width * 8) + self.scalar().map(|scalar| scalar.width) } pub const fn pointer_space(&self) -> Option<crate::AddressSpace> { @@ -532,6 +508,7 @@ impl crate::Expression { match *self { Self::Literal(_) | Self::Constant(_) + | Self::Override(_) | Self::ZeroValue(_) | Self::FunctionArgument(_) | Self::GlobalVariable(_) @@ -553,13 +530,9 @@ impl crate::Expression { /// /// [`Access`]: crate::Expression::Access /// [`ResolveContext`]: crate::proc::ResolveContext - pub fn is_dynamic_index(&self, module: &crate::Module) -> bool { + pub const fn is_dynamic_index(&self) -> bool { match *self { - Self::Literal(_) | Self::ZeroValue(_) => false, - Self::Constant(handle) => { - let constant = &module.constants[handle]; - !matches!(constant.r#override, crate::Override::None) - } + Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false, _ => true, } } @@ -652,7 +625,8 @@ impl crate::Module { GlobalCtx { types: &self.types, constants: &self.constants, - const_expressions: &self.const_expressions, + overrides: &self.overrides, + global_expressions: &self.global_expressions, } } } @@ -667,17 +641,18 @@ pub(super) enum U32EvalError { pub struct GlobalCtx<'a> { pub types: &'a crate::UniqueArena<crate::Type>, pub constants: &'a crate::Arena<crate::Constant>, - pub const_expressions: &'a crate::Arena<crate::Expression>, + pub overrides: &'a crate::Arena<crate::Override>, + pub global_expressions: &'a crate::Arena<crate::Expression>, } impl GlobalCtx<'_> { - /// Try to evaluate the expression in `self.const_expressions` using its `handle` and return it as a `u32`. + /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`. #[allow(dead_code)] pub(super) fn eval_expr_to_u32( &self, handle: crate::Handle<crate::Expression>, ) -> Result<u32, U32EvalError> { - self.eval_expr_to_u32_from(handle, self.const_expressions) + self.eval_expr_to_u32_from(handle, self.global_expressions) } /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`. @@ -700,7 +675,7 @@ impl GlobalCtx<'_> { &self, handle: crate::Handle<crate::Expression>, ) -> Option<crate::Literal> { - self.eval_expr_to_literal_from(handle, self.const_expressions) + self.eval_expr_to_literal_from(handle, self.global_expressions) } fn eval_expr_to_literal_from( @@ -724,7 +699,7 @@ impl GlobalCtx<'_> { } match arena[handle] { crate::Expression::Constant(c) => { - get(*self, self.constants[c].init, self.const_expressions) + get(*self, self.constants[c].init, self.global_expressions) } _ => get(*self, handle, arena), } diff --git a/third_party/rust/naga/src/proc/terminator.rs b/third_party/rust/naga/src/proc/terminator.rs index a5239d4eca..5edf55cb73 100644 --- a/third_party/rust/naga/src/proc/terminator.rs +++ b/third_party/rust/naga/src/proc/terminator.rs @@ -37,6 +37,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::RayQuery { .. } | S::Atomic { .. } | S::WorkGroupUniformLoad { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs index 9c4403445c..3936e7efbe 100644 --- a/third_party/rust/naga/src/proc/typifier.rs +++ b/third_party/rust/naga/src/proc/typifier.rs @@ -185,6 +185,7 @@ pub enum ResolveError { pub struct ResolveContext<'a> { pub constants: &'a Arena<crate::Constant>, + pub overrides: &'a Arena<crate::Override>, pub types: &'a UniqueArena<crate::Type>, pub special_types: &'a crate::SpecialTypes, pub global_vars: &'a Arena<crate::GlobalVariable>, @@ -202,6 +203,7 @@ impl<'a> ResolveContext<'a> { ) -> Self { Self { constants: &module.constants, + overrides: &module.overrides, types: &module.types, special_types: &module.special_types, global_vars: &module.global_variables, @@ -407,6 +409,7 @@ impl<'a> ResolveContext<'a> { }, crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()), crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty), + crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty), crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty), crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), crate::Expression::FunctionArgument(index) => { @@ -595,6 +598,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::ShiftRight => past(left)?.clone(), }, crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty), + crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty), crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty), crate::Expression::Select { accept, .. } => past(accept)?.clone(), crate::Expression::Derivative { expr, .. } => past(expr)?.clone(), @@ -882,6 +886,10 @@ impl<'a> ResolveContext<'a> { .ok_or(ResolveError::MissingSpecialType)?; TypeResolution::Handle(result) } + crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector { + scalar: crate::Scalar::U32, + size: crate::VectorSize::Quad, + }), }) } } |