summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/valid/function.rs
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/naga/src/valid/function.rs
parentInitial commit. (diff)
downloadfirefox-esr-upstream.tar.xz
firefox-esr-upstream.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/valid/function.rs')
-rw-r--r--third_party/rust/naga/src/valid/function.rs1048
1 files changed, 1048 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs
new file mode 100644
index 0000000000..8ebcddd9e3
--- /dev/null
+++ b/third_party/rust/naga/src/valid/function.rs
@@ -0,0 +1,1048 @@
+use crate::arena::Handle;
+#[cfg(feature = "validate")]
+use crate::arena::{Arena, UniqueArena};
+
+#[cfg(feature = "validate")]
+use super::validate_atomic_compare_exchange_struct;
+
+use super::{
+ analyzer::{UniformityDisruptor, UniformityRequirements},
+ ExpressionError, FunctionInfo, ModuleInfo,
+};
+use crate::span::WithSpan;
+#[cfg(feature = "validate")]
+use crate::span::{AddSpan as _, MapErrWithSpan as _};
+
+#[cfg(feature = "validate")]
+use bit_set::BitSet;
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum CallError {
+ #[error("Argument {index} expression is invalid")]
+ Argument {
+ index: usize,
+ source: ExpressionError,
+ },
+ #[error("Result expression {0:?} has already been introduced earlier")]
+ ResultAlreadyInScope(Handle<crate::Expression>),
+ #[error("Result value is invalid")]
+ ResultValue(#[source] ExpressionError),
+ #[error("Requires {required} arguments, but {seen} are provided")]
+ ArgumentCount { required: usize, seen: usize },
+ #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
+ ArgumentType {
+ index: usize,
+ required: Handle<crate::Type>,
+ seen_expression: Handle<crate::Expression>,
+ },
+ #[error("The emitted expression doesn't match the call")]
+ ExpressionMismatch(Option<Handle<crate::Expression>>),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum AtomicError {
+ #[error("Pointer {0:?} to atomic is invalid.")]
+ InvalidPointer(Handle<crate::Expression>),
+ #[error("Operand {0:?} has invalid type.")]
+ InvalidOperand(Handle<crate::Expression>),
+ #[error("Result type for {0:?} doesn't match the statement")]
+ ResultTypeMismatch(Handle<crate::Expression>),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum LocalVariableError {
+ #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
+ InvalidType(Handle<crate::Type>),
+ #[error("Initializer doesn't match the variable type")]
+ InitializerType,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum FunctionError {
+ #[error("Expression {handle:?} is invalid")]
+ Expression {
+ handle: Handle<crate::Expression>,
+ source: ExpressionError,
+ },
+ #[error("Expression {0:?} can't be introduced - it's already in scope")]
+ ExpressionAlreadyInScope(Handle<crate::Expression>),
+ #[error("Local variable {handle:?} '{name}' is invalid")]
+ LocalVariable {
+ handle: Handle<crate::LocalVariable>,
+ name: String,
+ source: LocalVariableError,
+ },
+ #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
+ InvalidArgumentType { index: usize, name: String },
+ #[error("The function's given return type cannot be returned from functions")]
+ NonConstructibleReturnType,
+ #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
+ InvalidArgumentPointerSpace {
+ index: usize,
+ name: String,
+ space: crate::AddressSpace,
+ },
+ #[error("There are instructions after `return`/`break`/`continue`")]
+ InstructionsAfterReturn,
+ #[error("The `break` is used outside of a `loop` or `switch` context")]
+ BreakOutsideOfLoopOrSwitch,
+ #[error("The `continue` is used outside of a `loop` context")]
+ ContinueOutsideOfLoop,
+ #[error("The `return` is called within a `continuing` block")]
+ InvalidReturnSpot,
+ #[error("The `return` value {0:?} does not match the function return value")]
+ InvalidReturnType(Option<Handle<crate::Expression>>),
+ #[error("The `if` condition {0:?} is not a boolean scalar")]
+ InvalidIfType(Handle<crate::Expression>),
+ #[error("The `switch` value {0:?} is not an integer scalar")]
+ InvalidSwitchType(Handle<crate::Expression>),
+ #[error("Multiple `switch` cases for {0:?} are present")]
+ ConflictingSwitchCase(crate::SwitchValue),
+ #[error("The `switch` contains cases with conflicting types")]
+ ConflictingCaseType,
+ #[error("The `switch` is missing a `default` case")]
+ MissingDefaultCase,
+ #[error("Multiple `default` cases are present")]
+ MultipleDefaultCases,
+ #[error("The last `switch` case contains a `falltrough`")]
+ LastCaseFallTrough,
+ #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
+ InvalidStorePointer(Handle<crate::Expression>),
+ #[error("The value {0:?} can not be stored")]
+ InvalidStoreValue(Handle<crate::Expression>),
+ #[error("Store of {value:?} into {pointer:?} doesn't have matching types")]
+ InvalidStoreTypes {
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ },
+ #[error("Image store parameters are invalid")]
+ InvalidImageStore(#[source] ExpressionError),
+ #[error("Call to {function:?} is invalid")]
+ InvalidCall {
+ function: Handle<crate::Function>,
+ #[source]
+ error: CallError,
+ },
+ #[error("Atomic operation is invalid")]
+ InvalidAtomic(#[from] AtomicError),
+ #[error("Ray Query {0:?} is not a local variable")]
+ InvalidRayQueryExpression(Handle<crate::Expression>),
+ #[error("Acceleration structure {0:?} is not a matching expression")]
+ InvalidAccelerationStructure(Handle<crate::Expression>),
+ #[error("Ray descriptor {0:?} is not a matching expression")]
+ InvalidRayDescriptor(Handle<crate::Expression>),
+ #[error("Ray Query {0:?} does not have a matching type")]
+ InvalidRayQueryType(Handle<crate::Type>),
+ #[error(
+ "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
+ )]
+ NonUniformControlFlow(
+ UniformityRequirements,
+ Handle<crate::Expression>,
+ UniformityDisruptor,
+ ),
+ #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
+ PipelineInputRegularFunction { name: String },
+ #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
+ PipelineOutputRegularFunction,
+}
+
+bitflags::bitflags! {
+ #[repr(transparent)]
+ struct ControlFlowAbility: u8 {
+ /// The control can return out of this block.
+ const RETURN = 0x1;
+ /// The control can break.
+ const BREAK = 0x2;
+ /// The control can continue.
+ const CONTINUE = 0x4;
+ }
+}
+
+#[cfg(feature = "validate")]
+struct BlockInfo {
+ stages: super::ShaderStages,
+ finished: bool,
+}
+
+#[cfg(feature = "validate")]
+struct BlockContext<'a> {
+ abilities: ControlFlowAbility,
+ info: &'a FunctionInfo,
+ expressions: &'a Arena<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ local_vars: &'a Arena<crate::LocalVariable>,
+ global_vars: &'a Arena<crate::GlobalVariable>,
+ functions: &'a Arena<crate::Function>,
+ special_types: &'a crate::SpecialTypes,
+ prev_infos: &'a [FunctionInfo],
+ return_type: Option<Handle<crate::Type>>,
+}
+
+#[cfg(feature = "validate")]
+impl<'a> BlockContext<'a> {
+ fn new(
+ fun: &'a crate::Function,
+ module: &'a crate::Module,
+ info: &'a FunctionInfo,
+ prev_infos: &'a [FunctionInfo],
+ ) -> Self {
+ Self {
+ abilities: ControlFlowAbility::RETURN,
+ info,
+ expressions: &fun.expressions,
+ types: &module.types,
+ local_vars: &fun.local_variables,
+ global_vars: &module.global_variables,
+ functions: &module.functions,
+ special_types: &module.special_types,
+ prev_infos,
+ return_type: fun.result.as_ref().map(|fr| fr.ty),
+ }
+ }
+
+ const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
+ BlockContext { abilities, ..*self }
+ }
+
+ fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
+ &self.expressions[handle]
+ }
+
+ fn resolve_type_impl(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
+ if handle.index() >= self.expressions.len() {
+ Err(ExpressionError::DoesntExist.with_span())
+ } else if !valid_expressions.contains(handle.index()) {
+ Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+
+ fn resolve_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
+ self.resolve_type_impl(handle, valid_expressions)
+ .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
+ }
+
+ fn resolve_pointer_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<&crate::TypeInner, FunctionError> {
+ if handle.index() >= self.expressions.len() {
+ Err(FunctionError::Expression {
+ handle,
+ source: ExpressionError::DoesntExist,
+ })
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+}
+
+impl super::Validator {
+ #[cfg(feature = "validate")]
+ fn validate_call(
+ &mut self,
+ function: Handle<crate::Function>,
+ arguments: &[Handle<crate::Expression>],
+ result: Option<Handle<crate::Expression>>,
+ context: &BlockContext,
+ ) -> Result<super::ShaderStages, WithSpan<CallError>> {
+ let fun = &context.functions[function];
+ if fun.arguments.len() != arguments.len() {
+ return Err(CallError::ArgumentCount {
+ required: fun.arguments.len(),
+ seen: arguments.len(),
+ }
+ .with_span());
+ }
+ for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
+ let ty = context
+ .resolve_type_impl(expr, &self.valid_expression_set)
+ .map_err_inner(|source| {
+ CallError::Argument { index, source }
+ .with_span_handle(expr, context.expressions)
+ })?;
+ let arg_inner = &context.types[arg.ty].inner;
+ if !ty.equivalent(arg_inner, context.types) {
+ return Err(CallError::ArgumentType {
+ index,
+ required: arg.ty,
+ seen_expression: expr,
+ }
+ .with_span_handle(expr, context.expressions));
+ }
+ }
+
+ if let Some(expr) = result {
+ if self.valid_expression_set.insert(expr.index()) {
+ self.valid_expression_list.push(expr);
+ } else {
+ return Err(CallError::ResultAlreadyInScope(expr)
+ .with_span_handle(expr, context.expressions));
+ }
+ match context.expressions[expr] {
+ crate::Expression::CallResult(callee)
+ if fun.result.is_some() && callee == function => {}
+ _ => {
+ return Err(CallError::ExpressionMismatch(result)
+ .with_span_handle(expr, context.expressions))
+ }
+ }
+ } else if fun.result.is_some() {
+ return Err(CallError::ExpressionMismatch(result).with_span());
+ }
+
+ let callee_info = &context.prev_infos[function.index()];
+ Ok(callee_info.available_stages)
+ }
+
+ #[cfg(feature = "validate")]
+ fn emit_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ if self.valid_expression_set.insert(handle.index()) {
+ self.valid_expression_list.push(handle);
+ Ok(())
+ } else {
+ Err(FunctionError::ExpressionAlreadyInScope(handle)
+ .with_span_handle(handle, context.expressions))
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_atomic(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ fun: &crate::AtomicFunction,
+ value: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
+ let (ptr_kind, ptr_width) = match *pointer_inner {
+ crate::TypeInner::Pointer { base, .. } => match context.types[base].inner {
+ crate::TypeInner::Atomic { kind, width } => (kind, width),
+ ref other => {
+ log::error!("Atomic pointer to type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ },
+ ref other => {
+ log::error!("Atomic on type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ };
+
+ let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
+ match *value_inner {
+ crate::TypeInner::Scalar { width, kind } if kind == ptr_kind && width == ptr_width => {}
+ ref other => {
+ log::error!("Atomic operand type {:?}", other);
+ return Err(AtomicError::InvalidOperand(value)
+ .with_span_handle(value, context.expressions)
+ .into_other());
+ }
+ }
+
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner {
+ log::error!("Atomic exchange comparison has a different type from the value");
+ return Err(AtomicError::InvalidOperand(cmp)
+ .with_span_handle(cmp, context.expressions)
+ .into_other());
+ }
+ }
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::AtomicResult { ty, comparison }
+ if {
+ let scalar_predicate = |ty: &crate::TypeInner| {
+ *ty == crate::TypeInner::Scalar {
+ kind: ptr_kind,
+ width: ptr_width,
+ }
+ };
+ match &context.types[ty].inner {
+ ty if !comparison => scalar_predicate(ty),
+ &crate::TypeInner::Struct { ref members, .. } if comparison => {
+ validate_atomic_compare_exchange_struct(
+ context.types,
+ members,
+ scalar_predicate,
+ )
+ }
+ _ => false,
+ }
+ } => {}
+ _ => {
+ return Err(AtomicError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_block_impl(
+ &mut self,
+ statements: &crate::Block,
+ context: &BlockContext,
+ ) -> Result<BlockInfo, WithSpan<FunctionError>> {
+ use crate::{Statement as S, TypeInner as Ti};
+ let mut finished = false;
+ let mut stages = super::ShaderStages::all();
+ for (statement, &span) in statements.span_iter() {
+ if finished {
+ return Err(FunctionError::InstructionsAfterReturn
+ .with_span_static(span, "instructions after return"));
+ }
+ match *statement {
+ S::Emit(ref range) => {
+ for handle in range.clone() {
+ self.emit_expression(handle, context)?;
+ }
+ }
+ S::Block(ref block) => {
+ let info = self.validate_block(block, context)?;
+ stages &= info.stages;
+ finished = info.finished;
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ match *context.resolve_type(condition, &self.valid_expression_set)? {
+ Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ } => {}
+ _ => {
+ return Err(FunctionError::InvalidIfType(condition)
+ .with_span_handle(condition, context.expressions))
+ }
+ }
+ stages &= self.validate_block(accept, context)?.stages;
+ stages &= self.validate_block(reject, context)?.stages;
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ let uint = match context
+ .resolve_type(selector, &self.valid_expression_set)?
+ .scalar_kind()
+ {
+ Some(crate::ScalarKind::Uint) => true,
+ Some(crate::ScalarKind::Sint) => false,
+ _ => {
+ return Err(FunctionError::InvalidSwitchType(selector)
+ .with_span_handle(selector, context.expressions))
+ }
+ };
+ self.switch_values.clear();
+ for case in cases {
+ match case.value {
+ crate::SwitchValue::I32(_) if !uint => {}
+ crate::SwitchValue::U32(_) if uint => {}
+ crate::SwitchValue::Default => {}
+ _ => {
+ return Err(FunctionError::ConflictingCaseType.with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "conflicting switch arm here",
+ ));
+ }
+ };
+ if !self.switch_values.insert(case.value) {
+ return Err(match case.value {
+ crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
+ .with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "duplicated switch arm here",
+ ),
+ _ => FunctionError::ConflictingSwitchCase(case.value)
+ .with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "conflicting switch arm here",
+ ),
+ });
+ }
+ }
+ if !self.switch_values.contains(&crate::SwitchValue::Default) {
+ return Err(FunctionError::MissingDefaultCase
+ .with_span_static(span, "missing default case"));
+ }
+ if let Some(case) = cases.last() {
+ if case.fall_through {
+ return Err(FunctionError::LastCaseFallTrough.with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "bad switch arm here",
+ ));
+ }
+ }
+ let pass_through_abilities = context.abilities
+ & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
+ let sub_context =
+ context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
+ for case in cases {
+ stages &= self.validate_block(&case.body, &sub_context)?.stages;
+ }
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ // special handling for block scoping is needed here,
+ // because the continuing{} block inherits the scope
+ let base_expression_count = self.valid_expression_list.len();
+ let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
+ stages &= self
+ .validate_block_impl(
+ body,
+ &context.with_abilities(
+ pass_through_abilities
+ | ControlFlowAbility::BREAK
+ | ControlFlowAbility::CONTINUE,
+ ),
+ )?
+ .stages;
+ stages &= self
+ .validate_block_impl(
+ continuing,
+ &context.with_abilities(ControlFlowAbility::empty()),
+ )?
+ .stages;
+
+ if let Some(condition) = break_if {
+ match *context.resolve_type(condition, &self.valid_expression_set)? {
+ Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ } => {}
+ _ => {
+ return Err(FunctionError::InvalidIfType(condition)
+ .with_span_handle(condition, context.expressions))
+ }
+ }
+ }
+
+ for handle in self.valid_expression_list.drain(base_expression_count..) {
+ self.valid_expression_set.remove(handle.index());
+ }
+ }
+ S::Break => {
+ if !context.abilities.contains(ControlFlowAbility::BREAK) {
+ return Err(FunctionError::BreakOutsideOfLoopOrSwitch
+ .with_span_static(span, "invalid break"));
+ }
+ finished = true;
+ }
+ S::Continue => {
+ if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
+ return Err(FunctionError::ContinueOutsideOfLoop
+ .with_span_static(span, "invalid continue"));
+ }
+ finished = true;
+ }
+ S::Return { value } => {
+ if !context.abilities.contains(ControlFlowAbility::RETURN) {
+ return Err(FunctionError::InvalidReturnSpot
+ .with_span_static(span, "invalid return"));
+ }
+ let value_ty = value
+ .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
+ .transpose()?;
+ let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
+ // We can't return pointers, but it seems best not to embed that
+ // assumption here, so use `TypeInner::equivalent` for comparison.
+ let okay = match (value_ty, expected_ty) {
+ (None, None) => true,
+ (Some(value_inner), Some(expected_inner)) => {
+ value_inner.equivalent(expected_inner, context.types)
+ }
+ (_, _) => false,
+ };
+
+ if !okay {
+ log::error!(
+ "Returning {:?} where {:?} is expected",
+ value_ty,
+ expected_ty
+ );
+ if let Some(handle) = value {
+ return Err(FunctionError::InvalidReturnType(value)
+ .with_span_handle(handle, context.expressions));
+ } else {
+ return Err(FunctionError::InvalidReturnType(value)
+ .with_span_static(span, "invalid return"));
+ }
+ }
+ finished = true;
+ }
+ S::Kill => {
+ stages &= super::ShaderStages::FRAGMENT;
+ finished = true;
+ }
+ S::Barrier(_) => {
+ stages &= super::ShaderStages::COMPUTE;
+ }
+ S::Store { pointer, value } => {
+ let mut current = pointer;
+ loop {
+ let _ = context
+ .resolve_pointer_type(current)
+ .map_err(|e| e.with_span())?;
+ match context.expressions[current] {
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => current = base,
+ crate::Expression::LocalVariable(_)
+ | crate::Expression::GlobalVariable(_)
+ | crate::Expression::FunctionArgument(_) => break,
+ _ => {
+ return Err(FunctionError::InvalidStorePointer(current)
+ .with_span_handle(pointer, context.expressions))
+ }
+ }
+ }
+
+ let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
+ match *value_ty {
+ Ti::Image { .. } | Ti::Sampler { .. } => {
+ return Err(FunctionError::InvalidStoreValue(value)
+ .with_span_handle(value, context.expressions));
+ }
+ _ => {}
+ }
+
+ let pointer_ty = context
+ .resolve_pointer_type(pointer)
+ .map_err(|e| e.with_span())?;
+
+ let good = match *pointer_ty {
+ Ti::Pointer { base, space: _ } => match context.types[base].inner {
+ Ti::Atomic { kind, width } => *value_ty == Ti::Scalar { kind, width },
+ ref other => value_ty == other,
+ },
+ Ti::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space: _,
+ } => *value_ty == Ti::Vector { size, kind, width },
+ Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space: _,
+ } => *value_ty == Ti::Scalar { kind, width },
+ _ => false,
+ };
+ if !good {
+ return Err(FunctionError::InvalidStoreTypes { pointer, value }
+ .with_span()
+ .with_handle(pointer, context.expressions)
+ .with_handle(value, context.expressions));
+ }
+
+ if let Some(space) = pointer_ty.pointer_space() {
+ if !space.access().contains(crate::StorageAccess::STORE) {
+ return Err(FunctionError::InvalidStorePointer(pointer)
+ .with_span_static(
+ context.expressions.get_span(pointer),
+ "writing to this location is not permitted",
+ ));
+ }
+ }
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ //Note: this code uses a lot of `FunctionError::InvalidImageStore`,
+ // and could probably be refactored.
+ let var = match *context.get_expression(image) {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &context.global_vars[var_handle]
+ }
+ // We're looking at a binding index situation, so punch through the index and look at the global behind it.
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => {
+ match *context.get_expression(base) {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &context.global_vars[var_handle]
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedGlobalVariable,
+ )
+ .with_span_handle(image, context.expressions))
+ }
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedGlobalVariable,
+ )
+ .with_span_handle(image, context.expressions))
+ }
+ };
+
+ // Punch through a binding array to get the underlying type
+ let global_ty = match context.types[var.ty].inner {
+ Ti::BindingArray { base, .. } => &context.types[base].inner,
+ ref inner => inner,
+ };
+
+ let value_ty = match *global_ty {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ match context
+ .resolve_type(coordinate, &self.valid_expression_set)?
+ .image_storage_coordinates()
+ {
+ Some(coord_dim) if coord_dim == dim => {}
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageCoordinateType(
+ dim, coordinate,
+ ),
+ )
+ .with_span_handle(coordinate, context.expressions));
+ }
+ };
+ if arrayed != array_index.is_some() {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageArrayIndex,
+ )
+ .with_span_handle(coordinate, context.expressions));
+ }
+ if let Some(expr) = array_index {
+ match *context.resolve_type(expr, &self.valid_expression_set)? {
+ Ti::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ width: _,
+ } => {}
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageArrayIndexType(expr),
+ )
+ .with_span_handle(expr, context.expressions));
+ }
+ }
+ }
+ match class {
+ crate::ImageClass::Storage { format, .. } => {
+ crate::TypeInner::Vector {
+ kind: format.into(),
+ size: crate::VectorSize::Quad,
+ width: 4,
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageClass(class),
+ )
+ .with_span_handle(image, context.expressions));
+ }
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedImageType(var.ty),
+ )
+ .with_span()
+ .with_handle(var.ty, context.types)
+ .with_handle(image, context.expressions))
+ }
+ };
+
+ if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
+ return Err(FunctionError::InvalidStoreValue(value)
+ .with_span_handle(value, context.expressions));
+ }
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result,
+ } => match self.validate_call(function, arguments, result, context) {
+ Ok(callee_stages) => stages &= callee_stages,
+ Err(error) => {
+ return Err(error.and_then(|error| {
+ FunctionError::InvalidCall { function, error }
+ .with_span_static(span, "invalid function call")
+ }))
+ }
+ },
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ self.validate_atomic(pointer, fun, value, result, context)?;
+ }
+ S::RayQuery { query, ref fun } => {
+ let query_var = match *context.get_expression(query) {
+ crate::Expression::LocalVariable(var) => &context.local_vars[var],
+ ref other => {
+ log::error!("Unexpected ray query expression {other:?}");
+ return Err(FunctionError::InvalidRayQueryExpression(query)
+ .with_span_static(span, "invalid query expression"));
+ }
+ };
+ match context.types[query_var.ty].inner {
+ Ti::RayQuery => {}
+ ref other => {
+ log::error!("Unexpected ray query type {other:?}");
+ return Err(FunctionError::InvalidRayQueryType(query_var.ty)
+ .with_span_static(span, "invalid query type"));
+ }
+ }
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ match *context
+ .resolve_type(acceleration_structure, &self.valid_expression_set)?
+ {
+ Ti::AccelerationStructure => {}
+ _ => {
+ return Err(FunctionError::InvalidAccelerationStructure(
+ acceleration_structure,
+ )
+ .with_span_static(span, "invalid acceleration structure"))
+ }
+ }
+ let desc_ty_given =
+ context.resolve_type(descriptor, &self.valid_expression_set)?;
+ let desc_ty_expected = context
+ .special_types
+ .ray_desc
+ .map(|handle| &context.types[handle].inner);
+ if Some(desc_ty_given) != desc_ty_expected {
+ return Err(FunctionError::InvalidRayDescriptor(descriptor)
+ .with_span_static(span, "invalid ray descriptor"));
+ }
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ self.emit_expression(result, context)?;
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ }
+ }
+ }
+ Ok(BlockInfo { stages, finished })
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_block(
+ &mut self,
+ statements: &crate::Block,
+ context: &BlockContext,
+ ) -> Result<BlockInfo, WithSpan<FunctionError>> {
+ let base_expression_count = self.valid_expression_list.len();
+ let info = self.validate_block_impl(statements, context)?;
+ for handle in self.valid_expression_list.drain(base_expression_count..) {
+ self.valid_expression_set.remove(handle.index());
+ }
+ Ok(info)
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_local_var(
+ &self,
+ var: &crate::LocalVariable,
+ types: &UniqueArena<crate::Type>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<(), LocalVariableError> {
+ log::debug!("var {:?}", var);
+ let type_info = self
+ .types
+ .get(var.ty.index())
+ .ok_or(LocalVariableError::InvalidType(var.ty))?;
+ if !type_info
+ .flags
+ .contains(super::TypeFlags::DATA | super::TypeFlags::SIZED)
+ {
+ return Err(LocalVariableError::InvalidType(var.ty));
+ }
+
+ if let Some(const_handle) = var.init {
+ match constants[const_handle].inner {
+ crate::ConstantInner::Scalar { width, ref value } => {
+ let ty_inner = crate::TypeInner::Scalar {
+ width,
+ kind: value.scalar_kind(),
+ };
+ if types[var.ty].inner != ty_inner {
+ return Err(LocalVariableError::InitializerType);
+ }
+ }
+ crate::ConstantInner::Composite { ty, components: _ } => {
+ if ty != var.ty {
+ return Err(LocalVariableError::InitializerType);
+ }
+ }
+ }
+ }
+ Ok(())
+ }
+
+ pub(super) fn validate_function(
+ &mut self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ mod_info: &ModuleInfo,
+ #[cfg_attr(not(feature = "validate"), allow(unused))] entry_point: bool,
+ ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
+ #[cfg_attr(not(feature = "validate"), allow(unused_mut))]
+ let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
+
+ #[cfg(feature = "validate")]
+ for (var_handle, var) in fun.local_variables.iter() {
+ self.validate_local_var(var, &module.types, &module.constants)
+ .map_err(|source| {
+ FunctionError::LocalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(var.ty, &module.types)
+ .with_handle(var_handle, &fun.local_variables)
+ })?;
+ }
+
+ #[cfg(feature = "validate")]
+ for (index, argument) in fun.arguments.iter().enumerate() {
+ match module.types[argument.ty].inner.pointer_space() {
+ Some(
+ crate::AddressSpace::Private
+ | crate::AddressSpace::Function
+ | crate::AddressSpace::WorkGroup,
+ )
+ | None => {}
+ Some(other) => {
+ return Err(FunctionError::InvalidArgumentPointerSpace {
+ index,
+ name: argument.name.clone().unwrap_or_default(),
+ space: other,
+ }
+ .with_span_handle(argument.ty, &module.types))
+ }
+ }
+ // Check for the least informative error last.
+ if !self.types[argument.ty.index()]
+ .flags
+ .contains(super::TypeFlags::ARGUMENT)
+ {
+ return Err(FunctionError::InvalidArgumentType {
+ index,
+ name: argument.name.clone().unwrap_or_default(),
+ }
+ .with_span_handle(argument.ty, &module.types));
+ }
+
+ if !entry_point && argument.binding.is_some() {
+ return Err(FunctionError::PipelineInputRegularFunction {
+ name: argument.name.clone().unwrap_or_default(),
+ }
+ .with_span_handle(argument.ty, &module.types));
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ if let Some(ref result) = fun.result {
+ if !self.types[result.ty.index()]
+ .flags
+ .contains(super::TypeFlags::CONSTRUCTIBLE)
+ {
+ return Err(FunctionError::NonConstructibleReturnType
+ .with_span_handle(result.ty, &module.types));
+ }
+
+ if !entry_point && result.binding.is_some() {
+ return Err(FunctionError::PipelineOutputRegularFunction
+ .with_span_handle(result.ty, &module.types));
+ }
+ }
+
+ self.valid_expression_set.clear();
+ self.valid_expression_list.clear();
+ for (handle, expr) in fun.expressions.iter() {
+ if expr.needs_pre_emit() {
+ self.valid_expression_set.insert(handle.index());
+ }
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
+ match self.validate_expression(
+ handle,
+ expr,
+ fun,
+ module,
+ &info,
+ &mod_info.functions,
+ ) {
+ Ok(stages) => info.available_stages &= stages,
+ Err(source) => {
+ return Err(FunctionError::Expression { handle, source }
+ .with_span_handle(handle, &fun.expressions))
+ }
+ }
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::BLOCKS) {
+ let stages = self
+ .validate_block(
+ &fun.body,
+ &BlockContext::new(fun, module, &info, &mod_info.functions),
+ )?
+ .stages;
+ info.available_stages &= stages;
+ }
+ Ok(info)
+ }
+}