summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/proc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/proc')
-rw-r--r--third_party/rust/naga/src/proc/call_graph.rs74
-rw-r--r--third_party/rust/naga/src/proc/interface.rs290
-rw-r--r--third_party/rust/naga/src/proc/mod.rs67
-rw-r--r--third_party/rust/naga/src/proc/namer.rs113
-rw-r--r--third_party/rust/naga/src/proc/typifier.rs424
-rw-r--r--third_party/rust/naga/src/proc/validator.rs489
6 files changed, 1457 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/proc/call_graph.rs b/third_party/rust/naga/src/proc/call_graph.rs
new file mode 100644
index 0000000000..1c580d5c15
--- /dev/null
+++ b/third_party/rust/naga/src/proc/call_graph.rs
@@ -0,0 +1,74 @@
+use crate::{
+ arena::{Arena, Handle},
+ proc::{Interface, Visitor},
+ Function,
+};
+use petgraph::{
+ graph::{DefaultIx, NodeIndex},
+ Graph,
+};
+
+pub type CallGraph = Graph<Handle<Function>, ()>;
+
+pub struct CallGraphBuilder<'a> {
+ pub functions: &'a Arena<Function>,
+}
+
+impl<'a> CallGraphBuilder<'a> {
+ pub fn process(&self, func: &Function) -> CallGraph {
+ let mut graph = Graph::new();
+ let mut children = Vec::new();
+
+ let visitor = CallGraphVisitor {
+ children: &mut children,
+ };
+
+ let mut interface = Interface {
+ expressions: &func.expressions,
+ local_variables: &func.local_variables,
+ visitor,
+ };
+
+ interface.traverse(&func.body);
+
+ for handle in children {
+ let id = graph.add_node(handle);
+ self.collect(handle, id, &mut graph);
+ }
+
+ graph
+ }
+
+ fn collect(&self, handle: Handle<Function>, id: NodeIndex<DefaultIx>, graph: &mut CallGraph) {
+ let mut children = Vec::new();
+ let visitor = CallGraphVisitor {
+ children: &mut children,
+ };
+ let func = &self.functions[handle];
+
+ let mut interface = Interface {
+ expressions: &func.expressions,
+ local_variables: &func.local_variables,
+ visitor,
+ };
+
+ interface.traverse(&func.body);
+
+ for handle in children {
+ let child_id = graph.add_node(handle);
+ graph.add_edge(id, child_id, ());
+
+ self.collect(handle, child_id, graph);
+ }
+ }
+}
+
+struct CallGraphVisitor<'a> {
+ children: &'a mut Vec<Handle<Function>>,
+}
+
+impl<'a> Visitor for CallGraphVisitor<'a> {
+ fn visit_fun(&mut self, func: Handle<Function>) {
+ self.children.push(func)
+ }
+}
diff --git a/third_party/rust/naga/src/proc/interface.rs b/third_party/rust/naga/src/proc/interface.rs
new file mode 100644
index 0000000000..b512452fe2
--- /dev/null
+++ b/third_party/rust/naga/src/proc/interface.rs
@@ -0,0 +1,290 @@
+use crate::arena::{Arena, Handle};
+
+pub struct Interface<'a, T> {
+ pub expressions: &'a Arena<crate::Expression>,
+ pub local_variables: &'a Arena<crate::LocalVariable>,
+ pub visitor: T,
+}
+
+pub trait Visitor {
+ fn visit_expr(&mut self, _: &crate::Expression) {}
+ fn visit_lhs_expr(&mut self, _: &crate::Expression) {}
+ fn visit_fun(&mut self, _: Handle<crate::Function>) {}
+}
+
+impl<'a, T> Interface<'a, T>
+where
+ T: Visitor,
+{
+ fn traverse_expr(&mut self, handle: Handle<crate::Expression>) {
+ use crate::Expression as E;
+
+ let expr = &self.expressions[handle];
+
+ self.visitor.visit_expr(expr);
+
+ match *expr {
+ E::Access { base, index } => {
+ self.traverse_expr(base);
+ self.traverse_expr(index);
+ }
+ E::AccessIndex { base, .. } => {
+ self.traverse_expr(base);
+ }
+ E::Constant(_) => {}
+ E::Compose { ref components, .. } => {
+ for &comp in components {
+ self.traverse_expr(comp);
+ }
+ }
+ E::FunctionArgument(_) | E::GlobalVariable(_) | E::LocalVariable(_) => {}
+ E::Load { pointer } => {
+ self.traverse_expr(pointer);
+ }
+ E::ImageSample {
+ image,
+ sampler,
+ coordinate,
+ level,
+ depth_ref,
+ } => {
+ self.traverse_expr(image);
+ self.traverse_expr(sampler);
+ self.traverse_expr(coordinate);
+ match level {
+ crate::SampleLevel::Auto | crate::SampleLevel::Zero => (),
+ crate::SampleLevel::Exact(h) | crate::SampleLevel::Bias(h) => {
+ self.traverse_expr(h)
+ }
+ }
+ if let Some(dref) = depth_ref {
+ self.traverse_expr(dref);
+ }
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ index,
+ } => {
+ self.traverse_expr(image);
+ self.traverse_expr(coordinate);
+ if let Some(index) = index {
+ self.traverse_expr(index);
+ }
+ }
+ E::Unary { expr, .. } => {
+ self.traverse_expr(expr);
+ }
+ E::Binary { left, right, .. } => {
+ self.traverse_expr(left);
+ self.traverse_expr(right);
+ }
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ self.traverse_expr(condition);
+ self.traverse_expr(accept);
+ self.traverse_expr(reject);
+ }
+ E::Intrinsic { argument, .. } => {
+ self.traverse_expr(argument);
+ }
+ E::Transpose(matrix) => {
+ self.traverse_expr(matrix);
+ }
+ E::DotProduct(left, right) => {
+ self.traverse_expr(left);
+ self.traverse_expr(right);
+ }
+ E::CrossProduct(left, right) => {
+ self.traverse_expr(left);
+ self.traverse_expr(right);
+ }
+ E::As { expr, .. } => {
+ self.traverse_expr(expr);
+ }
+ E::Derivative { expr, .. } => {
+ self.traverse_expr(expr);
+ }
+ E::Call {
+ ref origin,
+ ref arguments,
+ } => {
+ for &argument in arguments {
+ self.traverse_expr(argument);
+ }
+ if let crate::FunctionOrigin::Local(fun) = *origin {
+ self.visitor.visit_fun(fun);
+ }
+ }
+ E::ArrayLength(expr) => {
+ self.traverse_expr(expr);
+ }
+ }
+ }
+
+ pub fn traverse(&mut self, block: &[crate::Statement]) {
+ for statement in block {
+ use crate::Statement as S;
+ match *statement {
+ S::Break | S::Continue | S::Kill => (),
+ S::Block(ref b) => {
+ self.traverse(b);
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ self.traverse_expr(condition);
+ self.traverse(accept);
+ self.traverse(reject);
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ ref default,
+ } => {
+ self.traverse_expr(selector);
+ for &(ref case, _) in cases.values() {
+ self.traverse(case);
+ }
+ self.traverse(default);
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ } => {
+ self.traverse(body);
+ self.traverse(continuing);
+ }
+ S::Return { value } => {
+ if let Some(expr) = value {
+ self.traverse_expr(expr);
+ }
+ }
+ S::Store { pointer, value } => {
+ let mut left = pointer;
+ loop {
+ match self.expressions[left] {
+ crate::Expression::Access { base, index } => {
+ self.traverse_expr(index);
+ left = base;
+ }
+ crate::Expression::AccessIndex { base, .. } => {
+ left = base;
+ }
+ _ => break,
+ }
+ }
+ self.visitor.visit_lhs_expr(&self.expressions[left]);
+ self.traverse_expr(value);
+ }
+ }
+ }
+ }
+}
+
+struct GlobalUseVisitor<'a>(&'a mut [crate::GlobalUse]);
+
+impl Visitor for GlobalUseVisitor<'_> {
+ fn visit_expr(&mut self, expr: &crate::Expression) {
+ if let crate::Expression::GlobalVariable(handle) = expr {
+ self.0[handle.index()] |= crate::GlobalUse::LOAD;
+ }
+ }
+
+ fn visit_lhs_expr(&mut self, expr: &crate::Expression) {
+ if let crate::Expression::GlobalVariable(handle) = expr {
+ self.0[handle.index()] |= crate::GlobalUse::STORE;
+ }
+ }
+}
+
+impl crate::Function {
+ pub fn fill_global_use(&mut self, globals: &Arena<crate::GlobalVariable>) {
+ self.global_usage.clear();
+ self.global_usage
+ .resize(globals.len(), crate::GlobalUse::empty());
+
+ let mut io = Interface {
+ expressions: &self.expressions,
+ local_variables: &self.local_variables,
+ visitor: GlobalUseVisitor(&mut self.global_usage),
+ };
+ io.traverse(&self.body);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ Arena, Expression, GlobalUse, GlobalVariable, Handle, Statement, StorageAccess,
+ StorageClass,
+ };
+
+ #[test]
+ fn global_use_scan() {
+ let test_global = GlobalVariable {
+ name: None,
+ class: StorageClass::Uniform,
+ binding: None,
+ ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()),
+ init: None,
+ interpolation: None,
+ storage_access: StorageAccess::empty(),
+ };
+ let mut test_globals = Arena::new();
+
+ let global_1 = test_globals.append(test_global.clone());
+ let global_2 = test_globals.append(test_global.clone());
+ let global_3 = test_globals.append(test_global.clone());
+ let global_4 = test_globals.append(test_global);
+
+ let mut expressions = Arena::new();
+ let global_1_expr = expressions.append(Expression::GlobalVariable(global_1));
+ let global_2_expr = expressions.append(Expression::GlobalVariable(global_2));
+ let global_3_expr = expressions.append(Expression::GlobalVariable(global_3));
+ let global_4_expr = expressions.append(Expression::GlobalVariable(global_4));
+
+ let test_body = vec![
+ Statement::Return {
+ value: Some(global_1_expr),
+ },
+ Statement::Store {
+ pointer: global_2_expr,
+ value: global_1_expr,
+ },
+ Statement::Store {
+ pointer: expressions.append(Expression::Access {
+ base: global_3_expr,
+ index: global_4_expr,
+ }),
+ value: global_1_expr,
+ },
+ ];
+
+ let mut function = crate::Function {
+ name: None,
+ arguments: Vec::new(),
+ return_type: None,
+ local_variables: Arena::new(),
+ expressions,
+ global_usage: Vec::new(),
+ body: test_body,
+ };
+ function.fill_global_use(&test_globals);
+
+ assert_eq!(
+ &function.global_usage,
+ &[
+ GlobalUse::LOAD,
+ GlobalUse::STORE,
+ GlobalUse::STORE,
+ GlobalUse::LOAD,
+ ],
+ )
+ }
+}
diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs
new file mode 100644
index 0000000000..961e55da7c
--- /dev/null
+++ b/third_party/rust/naga/src/proc/mod.rs
@@ -0,0 +1,67 @@
+//! Module processing functionality.
+
+#[cfg(feature = "petgraph")]
+mod call_graph;
+mod interface;
+mod namer;
+mod typifier;
+mod validator;
+
+#[cfg(feature = "petgraph")]
+pub use call_graph::{CallGraph, CallGraphBuilder};
+pub use interface::{Interface, Visitor};
+pub use namer::{EntryPointIndex, NameKey, Namer};
+pub use typifier::{check_constant_type, ResolveContext, ResolveError, Typifier};
+pub use validator::{ValidationError, Validator};
+
+impl From<super::StorageFormat> for super::ScalarKind {
+ fn from(format: super::StorageFormat) -> Self {
+ use super::{ScalarKind as Sk, StorageFormat as Sf};
+ match format {
+ Sf::R8Unorm => Sk::Float,
+ Sf::R8Snorm => Sk::Float,
+ Sf::R8Uint => Sk::Uint,
+ Sf::R8Sint => Sk::Sint,
+ Sf::R16Uint => Sk::Uint,
+ Sf::R16Sint => Sk::Sint,
+ Sf::R16Float => Sk::Float,
+ Sf::Rg8Unorm => Sk::Float,
+ Sf::Rg8Snorm => Sk::Float,
+ Sf::Rg8Uint => Sk::Uint,
+ Sf::Rg8Sint => Sk::Sint,
+ Sf::R32Uint => Sk::Uint,
+ Sf::R32Sint => Sk::Sint,
+ Sf::R32Float => Sk::Float,
+ Sf::Rg16Uint => Sk::Uint,
+ Sf::Rg16Sint => Sk::Sint,
+ Sf::Rg16Float => Sk::Float,
+ Sf::Rgba8Unorm => Sk::Float,
+ Sf::Rgba8Snorm => Sk::Float,
+ Sf::Rgba8Uint => Sk::Uint,
+ Sf::Rgba8Sint => Sk::Sint,
+ Sf::Rgb10a2Unorm => Sk::Float,
+ Sf::Rg11b10Float => Sk::Float,
+ Sf::Rg32Uint => Sk::Uint,
+ Sf::Rg32Sint => Sk::Sint,
+ Sf::Rg32Float => Sk::Float,
+ Sf::Rgba16Uint => Sk::Uint,
+ Sf::Rgba16Sint => Sk::Sint,
+ Sf::Rgba16Float => Sk::Float,
+ Sf::Rgba32Uint => Sk::Uint,
+ Sf::Rgba32Sint => Sk::Sint,
+ Sf::Rgba32Float => Sk::Float,
+ }
+ }
+}
+
+impl crate::TypeInner {
+ pub fn scalar_kind(&self) -> Option<super::ScalarKind> {
+ match *self {
+ super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => {
+ Some(kind)
+ }
+ super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float),
+ _ => None,
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/proc/namer.rs b/third_party/rust/naga/src/proc/namer.rs
new file mode 100644
index 0000000000..03b508904b
--- /dev/null
+++ b/third_party/rust/naga/src/proc/namer.rs
@@ -0,0 +1,113 @@
+use crate::{arena::Handle, FastHashMap};
+use std::collections::hash_map::Entry;
+
+pub type EntryPointIndex = u16;
+
+#[derive(Debug, Eq, Hash, PartialEq)]
+pub enum NameKey {
+ GlobalVariable(Handle<crate::GlobalVariable>),
+ Type(Handle<crate::Type>),
+ StructMember(Handle<crate::Type>, u32),
+ Function(Handle<crate::Function>),
+ FunctionArgument(Handle<crate::Function>, u32),
+ FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
+ EntryPoint(EntryPointIndex),
+ EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
+}
+
+/// This processor assigns names to all the things in a module
+/// that may need identifiers in a textual backend.
+pub struct Namer {
+ unique: FastHashMap<String, u32>,
+}
+
+impl Namer {
+ fn sanitize(string: &str) -> String {
+ let mut base = string
+ .chars()
+ .skip_while(|c| c.is_numeric())
+ .filter(|&c| c.is_ascii_alphanumeric() || c == '_')
+ .collect::<String>();
+ // close the name by '_' if the re is a number, so that
+ // we can have our own number!
+ match base.chars().next_back() {
+ Some(c) if !c.is_numeric() => {}
+ _ => base.push('_'),
+ };
+ base
+ }
+
+ fn call(&mut self, label_raw: &str) -> String {
+ let base = Self::sanitize(label_raw);
+ match self.unique.entry(base) {
+ Entry::Occupied(mut e) => {
+ *e.get_mut() += 1;
+ format!("{}{}", e.key(), e.get())
+ }
+ Entry::Vacant(e) => {
+ let name = e.key().to_string();
+ e.insert(0);
+ name
+ }
+ }
+ }
+
+ fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String {
+ self.call(match *label {
+ Some(ref name) => name,
+ None => fallback,
+ })
+ }
+
+ pub fn process(
+ module: &crate::Module,
+ reserved: &[&str],
+ output: &mut FastHashMap<NameKey, String>,
+ ) {
+ let mut this = Namer {
+ unique: reserved
+ .iter()
+ .map(|string| (string.to_string(), 0))
+ .collect(),
+ };
+
+ for (handle, var) in module.global_variables.iter() {
+ let name = this.call_or(&var.name, "global");
+ output.insert(NameKey::GlobalVariable(handle), name);
+ }
+
+ for (ty_handle, ty) in module.types.iter() {
+ let ty_name = this.call_or(&ty.name, "type");
+ output.insert(NameKey::Type(ty_handle), ty_name);
+
+ if let crate::TypeInner::Struct { ref members } = ty.inner {
+ for (index, member) in members.iter().enumerate() {
+ let name = this.call_or(&member.name, "member");
+ output.insert(NameKey::StructMember(ty_handle, index as u32), name);
+ }
+ }
+ }
+
+ for (fun_handle, fun) in module.functions.iter() {
+ let fun_name = this.call_or(&fun.name, "function");
+ output.insert(NameKey::Function(fun_handle), fun_name);
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = this.call_or(&arg.name, "param");
+ output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name);
+ }
+ for (handle, var) in fun.local_variables.iter() {
+ let name = this.call_or(&var.name, "local");
+ output.insert(NameKey::FunctionLocal(fun_handle, handle), name);
+ }
+ }
+
+ for (ep_index, (&(_, ref base_name), ep)) in module.entry_points.iter().enumerate() {
+ let ep_name = this.call(base_name);
+ output.insert(NameKey::EntryPoint(ep_index as _), ep_name);
+ for (handle, var) in ep.function.local_variables.iter() {
+ let name = this.call_or(&var.name, "local");
+ output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name);
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs
new file mode 100644
index 0000000000..b09893c179
--- /dev/null
+++ b/third_party/rust/naga/src/proc/typifier.rs
@@ -0,0 +1,424 @@
+use crate::arena::{Arena, Handle};
+
+use thiserror::Error;
+
+#[derive(Debug, PartialEq)]
+enum Resolution {
+ Handle(Handle<crate::Type>),
+ Value(crate::TypeInner),
+}
+
+// Clone is only implemented for numeric variants of `TypeInner`.
+impl Clone for Resolution {
+ fn clone(&self) -> Self {
+ match *self {
+ Resolution::Handle(handle) => Resolution::Handle(handle),
+ Resolution::Value(ref v) => Resolution::Value(match *v {
+ crate::TypeInner::Scalar { kind, width } => {
+ crate::TypeInner::Scalar { kind, width }
+ }
+ crate::TypeInner::Vector { size, kind, width } => {
+ crate::TypeInner::Vector { size, kind, width }
+ }
+ crate::TypeInner::Matrix {
+ rows,
+ columns,
+ width,
+ } => crate::TypeInner::Matrix {
+ rows,
+ columns,
+ width,
+ },
+ #[allow(clippy::panic)]
+ _ => panic!("Unepxected clone type: {:?}", v),
+ }),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Typifier {
+ resolutions: Vec<Resolution>,
+}
+
+#[derive(Clone, Debug, Error, PartialEq)]
+pub enum ResolveError {
+ #[error("Invalid index into array")]
+ InvalidAccessIndex,
+ #[error("Function {name} not defined")]
+ FunctionNotDefined { name: String },
+ #[error("Function without return type")]
+ FunctionReturnsVoid,
+ #[error("Type is not found in the given immutable arena")]
+ TypeNotFound,
+ #[error("Incompatible operand: {op} {operand}")]
+ IncompatibleOperand { op: String, operand: String },
+ #[error("Incompatible operands: {left} {op} {right}")]
+ IncompatibleOperands {
+ op: String,
+ left: String,
+ right: String,
+ },
+}
+
+pub struct ResolveContext<'a> {
+ pub constants: &'a Arena<crate::Constant>,
+ pub global_vars: &'a Arena<crate::GlobalVariable>,
+ pub local_vars: &'a Arena<crate::LocalVariable>,
+ pub functions: &'a Arena<crate::Function>,
+ pub arguments: &'a [crate::FunctionArgument],
+}
+
+impl Typifier {
+ pub fn new() -> Self {
+ Typifier {
+ resolutions: Vec::new(),
+ }
+ }
+
+ pub fn clear(&mut self) {
+ self.resolutions.clear()
+ }
+
+ pub fn get<'a>(
+ &'a self,
+ expr_handle: Handle<crate::Expression>,
+ types: &'a Arena<crate::Type>,
+ ) -> &'a crate::TypeInner {
+ match self.resolutions[expr_handle.index()] {
+ Resolution::Handle(ty_handle) => &types[ty_handle].inner,
+ Resolution::Value(ref inner) => inner,
+ }
+ }
+
+ pub fn get_handle(
+ &self,
+ expr_handle: Handle<crate::Expression>,
+ ) -> Option<Handle<crate::Type>> {
+ match self.resolutions[expr_handle.index()] {
+ Resolution::Handle(ty_handle) => Some(ty_handle),
+ Resolution::Value(_) => None,
+ }
+ }
+
+ fn resolve_impl(
+ &self,
+ expr: &crate::Expression,
+ types: &Arena<crate::Type>,
+ ctx: &ResolveContext,
+ ) -> Result<Resolution, ResolveError> {
+ Ok(match *expr {
+ crate::Expression::Access { base, .. } => match *self.get(base, types) {
+ crate::TypeInner::Array { base, .. } => Resolution::Handle(base),
+ crate::TypeInner::Vector {
+ size: _,
+ kind,
+ width,
+ } => Resolution::Value(crate::TypeInner::Scalar { kind, width }),
+ crate::TypeInner::Matrix {
+ rows: size,
+ columns: _,
+ width,
+ } => Resolution::Value(crate::TypeInner::Vector {
+ size,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "access".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::AccessIndex { base, index } => match *self.get(base, types) {
+ crate::TypeInner::Vector { size, kind, width } => {
+ if index >= size as u32 {
+ return Err(ResolveError::InvalidAccessIndex);
+ }
+ Resolution::Value(crate::TypeInner::Scalar { kind, width })
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ if index >= columns as u32 {
+ return Err(ResolveError::InvalidAccessIndex);
+ }
+ Resolution::Value(crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ })
+ }
+ crate::TypeInner::Array { base, .. } => Resolution::Handle(base),
+ crate::TypeInner::Struct { ref members } => {
+ let member = members
+ .get(index as usize)
+ .ok_or(ResolveError::InvalidAccessIndex)?;
+ Resolution::Handle(member.ty)
+ }
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "access index".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::Constant(h) => Resolution::Handle(ctx.constants[h].ty),
+ crate::Expression::Compose { ty, .. } => Resolution::Handle(ty),
+ crate::Expression::FunctionArgument(index) => {
+ Resolution::Handle(ctx.arguments[index as usize].ty)
+ }
+ crate::Expression::GlobalVariable(h) => Resolution::Handle(ctx.global_vars[h].ty),
+ crate::Expression::LocalVariable(h) => Resolution::Handle(ctx.local_vars[h].ty),
+ crate::Expression::Load { .. } => unimplemented!(),
+ crate::Expression::ImageSample { image, .. }
+ | crate::Expression::ImageLoad { image, .. } => match *self.get(image, types) {
+ crate::TypeInner::Image { class, .. } => Resolution::Value(match class {
+ crate::ImageClass::Depth => crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ crate::ImageClass::Sampled { kind, multi: _ } => crate::TypeInner::Vector {
+ kind,
+ width: 4,
+ size: crate::VectorSize::Quad,
+ },
+ crate::ImageClass::Storage(format) => crate::TypeInner::Vector {
+ kind: format.into(),
+ width: 4,
+ size: crate::VectorSize::Quad,
+ },
+ }),
+ _ => unreachable!(),
+ },
+ crate::Expression::Unary { expr, .. } => self.resolutions[expr.index()].clone(),
+ crate::Expression::Binary { op, left, right } => match op {
+ crate::BinaryOperator::Add
+ | crate::BinaryOperator::Subtract
+ | crate::BinaryOperator::Divide
+ | crate::BinaryOperator::Modulo => self.resolutions[left.index()].clone(),
+ crate::BinaryOperator::Multiply => {
+ let ty_left = self.get(left, types);
+ let ty_right = self.get(right, types);
+ if ty_left == ty_right {
+ self.resolutions[left.index()].clone()
+ } else if let crate::TypeInner::Scalar { .. } = *ty_right {
+ self.resolutions[left.index()].clone()
+ } else {
+ match *ty_left {
+ crate::TypeInner::Scalar { .. } => {
+ self.resolutions[right.index()].clone()
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows: _,
+ width,
+ } => Resolution::Value(crate::TypeInner::Vector {
+ size: columns,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ _ => {
+ return Err(ResolveError::IncompatibleOperands {
+ op: "x".to_string(),
+ left: format!("{:?}", ty_left),
+ right: format!("{:?}", ty_right),
+ })
+ }
+ }
+ }
+ }
+ crate::BinaryOperator::Equal
+ | crate::BinaryOperator::NotEqual
+ | crate::BinaryOperator::Less
+ | crate::BinaryOperator::LessEqual
+ | crate::BinaryOperator::Greater
+ | crate::BinaryOperator::GreaterEqual
+ | crate::BinaryOperator::LogicalAnd
+ | crate::BinaryOperator::LogicalOr => self.resolutions[left.index()].clone(),
+ crate::BinaryOperator::And
+ | crate::BinaryOperator::ExclusiveOr
+ | crate::BinaryOperator::InclusiveOr
+ | crate::BinaryOperator::ShiftLeft
+ | crate::BinaryOperator::ShiftRight => self.resolutions[left.index()].clone(),
+ },
+ crate::Expression::Select { accept, .. } => self.resolutions[accept.index()].clone(),
+ crate::Expression::Intrinsic { .. } => unimplemented!(),
+ crate::Expression::Transpose(expr) => match *self.get(expr, types) {
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => Resolution::Value(crate::TypeInner::Matrix {
+ columns: rows,
+ rows: columns,
+ width,
+ }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "transpose".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::DotProduct(left_expr, _) => match *self.get(left_expr, types) {
+ crate::TypeInner::Vector {
+ kind,
+ size: _,
+ width,
+ } => Resolution::Value(crate::TypeInner::Scalar { kind, width }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "dot product".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::CrossProduct(_, _) => unimplemented!(),
+ crate::Expression::As {
+ expr,
+ kind,
+ convert: _,
+ } => match *self.get(expr, types) {
+ crate::TypeInner::Scalar { kind: _, width } => {
+ Resolution::Value(crate::TypeInner::Scalar { kind, width })
+ }
+ crate::TypeInner::Vector {
+ kind: _,
+ size,
+ width,
+ } => Resolution::Value(crate::TypeInner::Vector { kind, size, width }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "as".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::Derivative { .. } => unimplemented!(),
+ crate::Expression::Call {
+ origin: crate::FunctionOrigin::External(ref name),
+ ref arguments,
+ } => match name.as_str() {
+ "distance" | "length" => match *self.get(arguments[0], types) {
+ crate::TypeInner::Vector { kind, width, .. }
+ | crate::TypeInner::Scalar { kind, width } => {
+ Resolution::Value(crate::TypeInner::Scalar { kind, width })
+ }
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: name.clone(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ "dot" => match *self.get(arguments[0], types) {
+ crate::TypeInner::Vector { kind, width, .. } => {
+ Resolution::Value(crate::TypeInner::Scalar { kind, width })
+ }
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: name.clone(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ //Note: `cross` is here too, we still need to figure out what to do with it
+ "abs" | "atan2" | "cos" | "sin" | "floor" | "inverse" | "normalize" | "min"
+ | "max" | "reflect" | "pow" | "clamp" | "fclamp" | "mix" | "step"
+ | "smoothstep" | "cross" => self.resolutions[arguments[0].index()].clone(),
+ _ => return Err(ResolveError::FunctionNotDefined { name: name.clone() }),
+ },
+ crate::Expression::Call {
+ origin: crate::FunctionOrigin::Local(handle),
+ arguments: _,
+ } => {
+ let ty = ctx.functions[handle]
+ .return_type
+ .ok_or(ResolveError::FunctionReturnsVoid)?;
+ Resolution::Handle(ty)
+ }
+ crate::Expression::ArrayLength(_) => Resolution::Value(crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }),
+ })
+ }
+
+ pub fn grow(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ types: &mut Arena<crate::Type>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ if self.resolutions.len() <= expr_handle.index() {
+ for (eh, expr) in expressions.iter().skip(self.resolutions.len()) {
+ let resolution = self.resolve_impl(expr, types, ctx)?;
+ log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution);
+
+ let ty_handle = match resolution {
+ Resolution::Handle(h) => h,
+ Resolution::Value(inner) => types
+ .fetch_if_or_append(crate::Type { name: None, inner }, |a, b| {
+ a.inner == b.inner
+ }),
+ };
+ self.resolutions.push(Resolution::Handle(ty_handle));
+ }
+ }
+ Ok(())
+ }
+
+ pub fn resolve_all(
+ &mut self,
+ expressions: &Arena<crate::Expression>,
+ types: &Arena<crate::Type>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ self.clear();
+ for (_, expr) in expressions.iter() {
+ let resolution = self.resolve_impl(expr, types, ctx)?;
+ self.resolutions.push(resolution);
+ }
+ Ok(())
+ }
+}
+
+pub fn check_constant_type(inner: &crate::ConstantInner, type_inner: &crate::TypeInner) -> bool {
+ match (inner, type_inner) {
+ (
+ crate::ConstantInner::Sint(_),
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ },
+ ) => true,
+ (
+ crate::ConstantInner::Uint(_),
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ },
+ ) => true,
+ (
+ crate::ConstantInner::Float(_),
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: _,
+ },
+ ) => true,
+ (
+ crate::ConstantInner::Bool(_),
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ },
+ ) => true,
+ (crate::ConstantInner::Composite(_inner), _) => true, // TODO recursively check composite types
+ (_, _) => false,
+ }
+}
diff --git a/third_party/rust/naga/src/proc/validator.rs b/third_party/rust/naga/src/proc/validator.rs
new file mode 100644
index 0000000000..d9d3eac659
--- /dev/null
+++ b/third_party/rust/naga/src/proc/validator.rs
@@ -0,0 +1,489 @@
+use super::typifier::{ResolveContext, ResolveError, Typifier};
+use crate::arena::{Arena, Handle};
+
+const MAX_BIND_GROUPS: u32 = 8;
+const MAX_LOCATIONS: u32 = 64; // using u64 mask
+const MAX_BIND_INDICES: u32 = 64; // using u64 mask
+const MAX_WORKGROUP_SIZE: u32 = 0x4000;
+
+#[derive(Debug)]
+pub struct Validator {
+ //Note: this is a bit tricky: some of the front-ends as well as backends
+ // already have to use the typifier, so the work here is redundant in a way.
+ typifier: Typifier,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum GlobalVariableError {
+ #[error("Usage isn't compatible with the storage class")]
+ InvalidUsage,
+ #[error("Type isn't compatible with the storage class")]
+ InvalidType,
+ #[error("Interpolation is not valid")]
+ InvalidInterpolation,
+ #[error("Storage access {seen:?} exceed the allowed {allowed:?}")]
+ InvalidStorageAccess {
+ allowed: crate::StorageAccess,
+ seen: crate::StorageAccess,
+ },
+ #[error("Binding decoration is missing or not applicable")]
+ InvalidBinding,
+ #[error("Binding is out of range")]
+ OutOfRangeBinding,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum LocalVariableError {
+ #[error("Initializer is not a constant expression")]
+ InitializerConst,
+ #[error("Initializer doesn't match the variable type")]
+ InitializerType,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum FunctionError {
+ #[error(transparent)]
+ Resolve(#[from] ResolveError),
+ #[error("There are instructions after `return`/`break`/`continue`")]
+ InvalidControlFlowExitTail,
+ #[error("Local variable {handle:?} '{name}' is invalid: {error:?}")]
+ LocalVariable {
+ handle: Handle<crate::LocalVariable>,
+ name: String,
+ error: LocalVariableError,
+ },
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum EntryPointError {
+ #[error("Early depth test is not applicable")]
+ UnexpectedEarlyDepthTest,
+ #[error("Workgroup size is not applicable")]
+ UnexpectedWorkgroupSize,
+ #[error("Workgroup size is out of range")]
+ OutOfRangeWorkgroupSize,
+ #[error("Global variable {0:?} is used incorrectly as {1:?}")]
+ InvalidGlobalUsage(Handle<crate::GlobalVariable>, crate::GlobalUse),
+ #[error("Bindings for {0:?} conflict with other global variables")]
+ BindingCollision(Handle<crate::GlobalVariable>),
+ #[error("Built-in {0:?} is not applicable to this entry point")]
+ InvalidBuiltIn(crate::BuiltIn),
+ #[error("Interpolation of an integer has to be flat")]
+ InvalidIntegerInterpolation,
+ #[error(transparent)]
+ Function(#[from] FunctionError),
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum ValidationError {
+ #[error("The type {0:?} width {1} is not supported")]
+ InvalidTypeWidth(crate::ScalarKind, crate::Bytes),
+ #[error("The type handle {0:?} can not be resolved")]
+ UnresolvedType(Handle<crate::Type>),
+ #[error("The constant {0:?} can not be used for an array size")]
+ InvalidArraySizeConstant(Handle<crate::Constant>),
+ #[error("Global variable {handle:?} '{name}' is invalid: {error:?}")]
+ GlobalVariable {
+ handle: Handle<crate::GlobalVariable>,
+ name: String,
+ error: GlobalVariableError,
+ },
+ #[error("Function {0:?} is invalid: {1:?}")]
+ Function(Handle<crate::Function>, FunctionError),
+ #[error("Entry point {name} at {stage:?} is invalid: {error:?}")]
+ EntryPoint {
+ stage: crate::ShaderStage,
+ name: String,
+ error: EntryPointError,
+ },
+ #[error("Module is corrupted")]
+ Corrupted,
+}
+
+impl crate::GlobalVariable {
+ fn forbid_interpolation(&self) -> Result<(), GlobalVariableError> {
+ match self.interpolation {
+ Some(_) => Err(GlobalVariableError::InvalidInterpolation),
+ None => Ok(()),
+ }
+ }
+
+ fn check_resource(&self) -> Result<(), GlobalVariableError> {
+ match self.binding {
+ Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point
+ Some(crate::Binding::Resource { group, binding }) => {
+ if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES {
+ return Err(GlobalVariableError::OutOfRangeBinding);
+ }
+ }
+ Some(crate::Binding::Location(_)) | None => {
+ return Err(GlobalVariableError::InvalidBinding)
+ }
+ }
+ self.forbid_interpolation()
+ }
+}
+
+fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse {
+ let mut storage_usage = crate::GlobalUse::empty();
+ if access.contains(crate::StorageAccess::LOAD) {
+ storage_usage |= crate::GlobalUse::LOAD;
+ }
+ if access.contains(crate::StorageAccess::STORE) {
+ storage_usage |= crate::GlobalUse::STORE;
+ }
+ storage_usage
+}
+
+impl Validator {
+ /// Construct a new validator instance.
+ pub fn new() -> Self {
+ Validator {
+ typifier: Typifier::new(),
+ }
+ }
+
+ fn validate_global_var(
+ &self,
+ var: &crate::GlobalVariable,
+ types: &Arena<crate::Type>,
+ ) -> Result<(), GlobalVariableError> {
+ log::debug!("var {:?}", var);
+ let allowed_storage_access = match var.class {
+ crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage),
+ crate::StorageClass::Input | crate::StorageClass::Output => {
+ match var.binding {
+ Some(crate::Binding::BuiltIn(_)) => {
+ // validated per entry point
+ var.forbid_interpolation()?
+ }
+ Some(crate::Binding::Location(loc)) => {
+ if loc > MAX_LOCATIONS {
+ return Err(GlobalVariableError::OutOfRangeBinding);
+ }
+ match types[var.ty].inner {
+ crate::TypeInner::Scalar { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. } => {}
+ _ => return Err(GlobalVariableError::InvalidType),
+ }
+ }
+ Some(crate::Binding::Resource { .. }) => {
+ return Err(GlobalVariableError::InvalidBinding)
+ }
+ None => {
+ match types[var.ty].inner {
+ //TODO: check the member types
+ crate::TypeInner::Struct { members: _ } => {
+ var.forbid_interpolation()?
+ }
+ _ => return Err(GlobalVariableError::InvalidType),
+ }
+ }
+ }
+ crate::StorageAccess::empty()
+ }
+ crate::StorageClass::Storage => {
+ var.check_resource()?;
+ crate::StorageAccess::all()
+ }
+ crate::StorageClass::Uniform => {
+ var.check_resource()?;
+ crate::StorageAccess::empty()
+ }
+ crate::StorageClass::Handle => {
+ var.check_resource()?;
+ match types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage(_),
+ ..
+ } => crate::StorageAccess::all(),
+ _ => crate::StorageAccess::empty(),
+ }
+ }
+ crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
+ if var.binding.is_some() {
+ return Err(GlobalVariableError::InvalidBinding);
+ }
+ var.forbid_interpolation()?;
+ crate::StorageAccess::empty()
+ }
+ crate::StorageClass::PushConstant => {
+ //TODO
+ return Err(GlobalVariableError::InvalidStorageAccess {
+ allowed: crate::StorageAccess::empty(),
+ seen: crate::StorageAccess::empty(),
+ });
+ }
+ };
+
+ if !allowed_storage_access.contains(var.storage_access) {
+ return Err(GlobalVariableError::InvalidStorageAccess {
+ allowed: allowed_storage_access,
+ seen: var.storage_access,
+ });
+ }
+
+ Ok(())
+ }
+
+ fn validate_local_var(
+ &self,
+ var: &crate::LocalVariable,
+ _fun: &crate::Function,
+ _types: &Arena<crate::Type>,
+ ) -> Result<(), LocalVariableError> {
+ log::debug!("var {:?}", var);
+ if let Some(_expr_handle) = var.init {
+ if false {
+ return Err(LocalVariableError::InitializerConst);
+ }
+ }
+ Ok(())
+ }
+
+ fn validate_function(
+ &mut self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ ) -> Result<(), FunctionError> {
+ let resolve_ctx = ResolveContext {
+ constants: &module.constants,
+ global_vars: &module.global_variables,
+ local_vars: &fun.local_variables,
+ functions: &module.functions,
+ arguments: &fun.arguments,
+ };
+ self.typifier
+ .resolve_all(&fun.expressions, &module.types, &resolve_ctx)?;
+
+ for (var_handle, var) in fun.local_variables.iter() {
+ self.validate_local_var(var, fun, &module.types)
+ .map_err(|error| FunctionError::LocalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ error,
+ })?;
+ }
+ Ok(())
+ }
+
+ fn validate_entry_point(
+ &mut self,
+ ep: &crate::EntryPoint,
+ stage: crate::ShaderStage,
+ module: &crate::Module,
+ ) -> Result<(), EntryPointError> {
+ if ep.early_depth_test.is_some() && stage != crate::ShaderStage::Fragment {
+ return Err(EntryPointError::UnexpectedEarlyDepthTest);
+ }
+ if stage == crate::ShaderStage::Compute {
+ if ep
+ .workgroup_size
+ .iter()
+ .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE)
+ {
+ return Err(EntryPointError::OutOfRangeWorkgroupSize);
+ }
+ } else if ep.workgroup_size != [0; 3] {
+ return Err(EntryPointError::UnexpectedWorkgroupSize);
+ }
+
+ let mut bind_group_masks = [0u64; MAX_BIND_GROUPS as usize];
+ let mut location_in_mask = 0u64;
+ let mut location_out_mask = 0u64;
+ for ((var_handle, var), &usage) in module
+ .global_variables
+ .iter()
+ .zip(&ep.function.global_usage)
+ {
+ if usage.is_empty() {
+ continue;
+ }
+
+ if let Some(crate::Binding::Location(_)) = var.binding {
+ match (stage, var.class) {
+ (crate::ShaderStage::Vertex, crate::StorageClass::Output)
+ | (crate::ShaderStage::Fragment, crate::StorageClass::Input) => {
+ match module.types[var.ty].inner.scalar_kind() {
+ Some(crate::ScalarKind::Float) => {}
+ Some(_) if var.interpolation != Some(crate::Interpolation::Flat) => {
+ return Err(EntryPointError::InvalidIntegerInterpolation);
+ }
+ _ => {}
+ }
+ }
+ _ => {}
+ }
+ }
+
+ let allowed_usage = match var.class {
+ crate::StorageClass::Function => unreachable!(),
+ crate::StorageClass::Input => {
+ let mask = match var.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => match (stage, built_in) {
+ (crate::ShaderStage::Vertex, crate::BuiltIn::BaseInstance)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::BaseVertex)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::InstanceIndex)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::VertexIndex)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::PointSize)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::FragCoord)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::FrontFacing)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::SampleIndex)
+ | (crate::ShaderStage::Compute, crate::BuiltIn::GlobalInvocationId)
+ | (crate::ShaderStage::Compute, crate::BuiltIn::LocalInvocationId)
+ | (crate::ShaderStage::Compute, crate::BuiltIn::LocalInvocationIndex)
+ | (crate::ShaderStage::Compute, crate::BuiltIn::WorkGroupId) => 0,
+ _ => return Err(EntryPointError::InvalidBuiltIn(built_in)),
+ },
+ Some(crate::Binding::Location(loc)) => 1 << loc,
+ Some(crate::Binding::Resource { .. }) => unreachable!(),
+ None => 0,
+ };
+ if location_in_mask & mask != 0 {
+ return Err(EntryPointError::BindingCollision(var_handle));
+ }
+ location_in_mask |= mask;
+ crate::GlobalUse::LOAD
+ }
+ crate::StorageClass::Output => {
+ let mask = match var.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => match (stage, built_in) {
+ (crate::ShaderStage::Vertex, crate::BuiltIn::Position)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::PointSize)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::ClipDistance)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::FragDepth) => 0,
+ _ => return Err(EntryPointError::InvalidBuiltIn(built_in)),
+ },
+ Some(crate::Binding::Location(loc)) => 1 << loc,
+ Some(crate::Binding::Resource { .. }) => unreachable!(),
+ None => 0,
+ };
+ if location_out_mask & mask != 0 {
+ return Err(EntryPointError::BindingCollision(var_handle));
+ }
+ location_out_mask |= mask;
+ crate::GlobalUse::LOAD | crate::GlobalUse::STORE
+ }
+ crate::StorageClass::Uniform => crate::GlobalUse::LOAD,
+ crate::StorageClass::Storage => storage_usage(var.storage_access),
+ crate::StorageClass::Handle => match module.types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage(_),
+ ..
+ } => storage_usage(var.storage_access),
+ _ => crate::GlobalUse::LOAD,
+ },
+ crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
+ crate::GlobalUse::all()
+ }
+ crate::StorageClass::PushConstant => crate::GlobalUse::LOAD,
+ };
+ if !allowed_usage.contains(usage) {
+ log::warn!("\tUsage error for: {:?}", var);
+ log::warn!(
+ "\tAllowed usage: {:?}, requested: {:?}",
+ allowed_usage,
+ usage
+ );
+ return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage));
+ }
+
+ if let Some(crate::Binding::Resource { group, binding }) = var.binding {
+ let mask = 1 << binding;
+ let group_mask = &mut bind_group_masks[group as usize];
+ if *group_mask & mask != 0 {
+ return Err(EntryPointError::BindingCollision(var_handle));
+ }
+ *group_mask |= mask;
+ }
+ }
+
+ self.validate_function(&ep.function, module)?;
+ Ok(())
+ }
+
+ /// Check the given module to be valid.
+ pub fn validate(&mut self, module: &crate::Module) -> Result<(), ValidationError> {
+ // check the types
+ for (handle, ty) in module.types.iter() {
+ use crate::TypeInner as Ti;
+ match ty.inner {
+ Ti::Scalar { kind, width } | Ti::Vector { kind, width, .. } => {
+ let expected = match kind {
+ crate::ScalarKind::Bool => 1,
+ _ => 4,
+ };
+ if width != expected {
+ return Err(ValidationError::InvalidTypeWidth(kind, width));
+ }
+ }
+ Ti::Matrix { width, .. } => {
+ if width != 4 {
+ return Err(ValidationError::InvalidTypeWidth(
+ crate::ScalarKind::Float,
+ width,
+ ));
+ }
+ }
+ Ti::Pointer { base, class: _ } => {
+ if base >= handle {
+ return Err(ValidationError::UnresolvedType(base));
+ }
+ }
+ Ti::Array { base, size, .. } => {
+ if base >= handle {
+ return Err(ValidationError::UnresolvedType(base));
+ }
+ if let crate::ArraySize::Constant(const_handle) = size {
+ let constant = module
+ .constants
+ .try_get(const_handle)
+ .ok_or(ValidationError::Corrupted)?;
+ match constant.inner {
+ crate::ConstantInner::Uint(_) => {}
+ _ => {
+ return Err(ValidationError::InvalidArraySizeConstant(const_handle))
+ }
+ }
+ }
+ }
+ Ti::Struct { ref members } => {
+ //TODO: check that offsets are not intersecting?
+ for member in members {
+ if member.ty >= handle {
+ return Err(ValidationError::UnresolvedType(member.ty));
+ }
+ }
+ }
+ Ti::Image { .. } => {}
+ Ti::Sampler { comparison: _ } => {}
+ }
+ }
+
+ for (var_handle, var) in module.global_variables.iter() {
+ self.validate_global_var(var, &module.types)
+ .map_err(|error| ValidationError::GlobalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ error,
+ })?;
+ }
+
+ for (fun_handle, fun) in module.functions.iter() {
+ self.validate_function(fun, module)
+ .map_err(|e| ValidationError::Function(fun_handle, e))?;
+ }
+
+ for (&(stage, ref name), entry_point) in module.entry_points.iter() {
+ self.validate_entry_point(entry_point, stage, module)
+ .map_err(|error| ValidationError::EntryPoint {
+ stage,
+ name: name.to_string(),
+ error,
+ })?;
+ }
+
+ Ok(())
+ }
+}