diff options
Diffstat (limited to 'third_party/rust/naga/src/proc')
-rw-r--r-- | third_party/rust/naga/src/proc/call_graph.rs | 74 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/interface.rs | 290 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/mod.rs | 67 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/namer.rs | 113 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/typifier.rs | 424 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/validator.rs | 489 |
6 files changed, 1457 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/proc/call_graph.rs b/third_party/rust/naga/src/proc/call_graph.rs new file mode 100644 index 0000000000..1c580d5c15 --- /dev/null +++ b/third_party/rust/naga/src/proc/call_graph.rs @@ -0,0 +1,74 @@ +use crate::{ + arena::{Arena, Handle}, + proc::{Interface, Visitor}, + Function, +}; +use petgraph::{ + graph::{DefaultIx, NodeIndex}, + Graph, +}; + +pub type CallGraph = Graph<Handle<Function>, ()>; + +pub struct CallGraphBuilder<'a> { + pub functions: &'a Arena<Function>, +} + +impl<'a> CallGraphBuilder<'a> { + pub fn process(&self, func: &Function) -> CallGraph { + let mut graph = Graph::new(); + let mut children = Vec::new(); + + let visitor = CallGraphVisitor { + children: &mut children, + }; + + let mut interface = Interface { + expressions: &func.expressions, + local_variables: &func.local_variables, + visitor, + }; + + interface.traverse(&func.body); + + for handle in children { + let id = graph.add_node(handle); + self.collect(handle, id, &mut graph); + } + + graph + } + + fn collect(&self, handle: Handle<Function>, id: NodeIndex<DefaultIx>, graph: &mut CallGraph) { + let mut children = Vec::new(); + let visitor = CallGraphVisitor { + children: &mut children, + }; + let func = &self.functions[handle]; + + let mut interface = Interface { + expressions: &func.expressions, + local_variables: &func.local_variables, + visitor, + }; + + interface.traverse(&func.body); + + for handle in children { + let child_id = graph.add_node(handle); + graph.add_edge(id, child_id, ()); + + self.collect(handle, child_id, graph); + } + } +} + +struct CallGraphVisitor<'a> { + children: &'a mut Vec<Handle<Function>>, +} + +impl<'a> Visitor for CallGraphVisitor<'a> { + fn visit_fun(&mut self, func: Handle<Function>) { + self.children.push(func) + } +} diff --git a/third_party/rust/naga/src/proc/interface.rs b/third_party/rust/naga/src/proc/interface.rs new file mode 100644 index 0000000000..b512452fe2 --- /dev/null +++ b/third_party/rust/naga/src/proc/interface.rs @@ -0,0 +1,290 @@ +use crate::arena::{Arena, Handle}; + +pub struct Interface<'a, T> { + pub expressions: &'a Arena<crate::Expression>, + pub local_variables: &'a Arena<crate::LocalVariable>, + pub visitor: T, +} + +pub trait Visitor { + fn visit_expr(&mut self, _: &crate::Expression) {} + fn visit_lhs_expr(&mut self, _: &crate::Expression) {} + fn visit_fun(&mut self, _: Handle<crate::Function>) {} +} + +impl<'a, T> Interface<'a, T> +where + T: Visitor, +{ + fn traverse_expr(&mut self, handle: Handle<crate::Expression>) { + use crate::Expression as E; + + let expr = &self.expressions[handle]; + + self.visitor.visit_expr(expr); + + match *expr { + E::Access { base, index } => { + self.traverse_expr(base); + self.traverse_expr(index); + } + E::AccessIndex { base, .. } => { + self.traverse_expr(base); + } + E::Constant(_) => {} + E::Compose { ref components, .. } => { + for &comp in components { + self.traverse_expr(comp); + } + } + E::FunctionArgument(_) | E::GlobalVariable(_) | E::LocalVariable(_) => {} + E::Load { pointer } => { + self.traverse_expr(pointer); + } + E::ImageSample { + image, + sampler, + coordinate, + level, + depth_ref, + } => { + self.traverse_expr(image); + self.traverse_expr(sampler); + self.traverse_expr(coordinate); + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Zero => (), + crate::SampleLevel::Exact(h) | crate::SampleLevel::Bias(h) => { + self.traverse_expr(h) + } + } + if let Some(dref) = depth_ref { + self.traverse_expr(dref); + } + } + E::ImageLoad { + image, + coordinate, + index, + } => { + self.traverse_expr(image); + self.traverse_expr(coordinate); + if let Some(index) = index { + self.traverse_expr(index); + } + } + E::Unary { expr, .. } => { + self.traverse_expr(expr); + } + E::Binary { left, right, .. } => { + self.traverse_expr(left); + self.traverse_expr(right); + } + E::Select { + condition, + accept, + reject, + } => { + self.traverse_expr(condition); + self.traverse_expr(accept); + self.traverse_expr(reject); + } + E::Intrinsic { argument, .. } => { + self.traverse_expr(argument); + } + E::Transpose(matrix) => { + self.traverse_expr(matrix); + } + E::DotProduct(left, right) => { + self.traverse_expr(left); + self.traverse_expr(right); + } + E::CrossProduct(left, right) => { + self.traverse_expr(left); + self.traverse_expr(right); + } + E::As { expr, .. } => { + self.traverse_expr(expr); + } + E::Derivative { expr, .. } => { + self.traverse_expr(expr); + } + E::Call { + ref origin, + ref arguments, + } => { + for &argument in arguments { + self.traverse_expr(argument); + } + if let crate::FunctionOrigin::Local(fun) = *origin { + self.visitor.visit_fun(fun); + } + } + E::ArrayLength(expr) => { + self.traverse_expr(expr); + } + } + } + + pub fn traverse(&mut self, block: &[crate::Statement]) { + for statement in block { + use crate::Statement as S; + match *statement { + S::Break | S::Continue | S::Kill => (), + S::Block(ref b) => { + self.traverse(b); + } + S::If { + condition, + ref accept, + ref reject, + } => { + self.traverse_expr(condition); + self.traverse(accept); + self.traverse(reject); + } + S::Switch { + selector, + ref cases, + ref default, + } => { + self.traverse_expr(selector); + for &(ref case, _) in cases.values() { + self.traverse(case); + } + self.traverse(default); + } + S::Loop { + ref body, + ref continuing, + } => { + self.traverse(body); + self.traverse(continuing); + } + S::Return { value } => { + if let Some(expr) = value { + self.traverse_expr(expr); + } + } + S::Store { pointer, value } => { + let mut left = pointer; + loop { + match self.expressions[left] { + crate::Expression::Access { base, index } => { + self.traverse_expr(index); + left = base; + } + crate::Expression::AccessIndex { base, .. } => { + left = base; + } + _ => break, + } + } + self.visitor.visit_lhs_expr(&self.expressions[left]); + self.traverse_expr(value); + } + } + } + } +} + +struct GlobalUseVisitor<'a>(&'a mut [crate::GlobalUse]); + +impl Visitor for GlobalUseVisitor<'_> { + fn visit_expr(&mut self, expr: &crate::Expression) { + if let crate::Expression::GlobalVariable(handle) = expr { + self.0[handle.index()] |= crate::GlobalUse::LOAD; + } + } + + fn visit_lhs_expr(&mut self, expr: &crate::Expression) { + if let crate::Expression::GlobalVariable(handle) = expr { + self.0[handle.index()] |= crate::GlobalUse::STORE; + } + } +} + +impl crate::Function { + pub fn fill_global_use(&mut self, globals: &Arena<crate::GlobalVariable>) { + self.global_usage.clear(); + self.global_usage + .resize(globals.len(), crate::GlobalUse::empty()); + + let mut io = Interface { + expressions: &self.expressions, + local_variables: &self.local_variables, + visitor: GlobalUseVisitor(&mut self.global_usage), + }; + io.traverse(&self.body); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + Arena, Expression, GlobalUse, GlobalVariable, Handle, Statement, StorageAccess, + StorageClass, + }; + + #[test] + fn global_use_scan() { + let test_global = GlobalVariable { + name: None, + class: StorageClass::Uniform, + binding: None, + ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()), + init: None, + interpolation: None, + storage_access: StorageAccess::empty(), + }; + let mut test_globals = Arena::new(); + + let global_1 = test_globals.append(test_global.clone()); + let global_2 = test_globals.append(test_global.clone()); + let global_3 = test_globals.append(test_global.clone()); + let global_4 = test_globals.append(test_global); + + let mut expressions = Arena::new(); + let global_1_expr = expressions.append(Expression::GlobalVariable(global_1)); + let global_2_expr = expressions.append(Expression::GlobalVariable(global_2)); + let global_3_expr = expressions.append(Expression::GlobalVariable(global_3)); + let global_4_expr = expressions.append(Expression::GlobalVariable(global_4)); + + let test_body = vec![ + Statement::Return { + value: Some(global_1_expr), + }, + Statement::Store { + pointer: global_2_expr, + value: global_1_expr, + }, + Statement::Store { + pointer: expressions.append(Expression::Access { + base: global_3_expr, + index: global_4_expr, + }), + value: global_1_expr, + }, + ]; + + let mut function = crate::Function { + name: None, + arguments: Vec::new(), + return_type: None, + local_variables: Arena::new(), + expressions, + global_usage: Vec::new(), + body: test_body, + }; + function.fill_global_use(&test_globals); + + assert_eq!( + &function.global_usage, + &[ + GlobalUse::LOAD, + GlobalUse::STORE, + GlobalUse::STORE, + GlobalUse::LOAD, + ], + ) + } +} diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs new file mode 100644 index 0000000000..961e55da7c --- /dev/null +++ b/third_party/rust/naga/src/proc/mod.rs @@ -0,0 +1,67 @@ +//! Module processing functionality. + +#[cfg(feature = "petgraph")] +mod call_graph; +mod interface; +mod namer; +mod typifier; +mod validator; + +#[cfg(feature = "petgraph")] +pub use call_graph::{CallGraph, CallGraphBuilder}; +pub use interface::{Interface, Visitor}; +pub use namer::{EntryPointIndex, NameKey, Namer}; +pub use typifier::{check_constant_type, ResolveContext, ResolveError, Typifier}; +pub use validator::{ValidationError, Validator}; + +impl From<super::StorageFormat> for super::ScalarKind { + fn from(format: super::StorageFormat) -> Self { + use super::{ScalarKind as Sk, StorageFormat as Sf}; + match format { + Sf::R8Unorm => Sk::Float, + Sf::R8Snorm => Sk::Float, + Sf::R8Uint => Sk::Uint, + Sf::R8Sint => Sk::Sint, + Sf::R16Uint => Sk::Uint, + Sf::R16Sint => Sk::Sint, + Sf::R16Float => Sk::Float, + Sf::Rg8Unorm => Sk::Float, + Sf::Rg8Snorm => Sk::Float, + Sf::Rg8Uint => Sk::Uint, + Sf::Rg8Sint => Sk::Sint, + Sf::R32Uint => Sk::Uint, + Sf::R32Sint => Sk::Sint, + Sf::R32Float => Sk::Float, + Sf::Rg16Uint => Sk::Uint, + Sf::Rg16Sint => Sk::Sint, + Sf::Rg16Float => Sk::Float, + Sf::Rgba8Unorm => Sk::Float, + Sf::Rgba8Snorm => Sk::Float, + Sf::Rgba8Uint => Sk::Uint, + Sf::Rgba8Sint => Sk::Sint, + Sf::Rgb10a2Unorm => Sk::Float, + Sf::Rg11b10Float => Sk::Float, + Sf::Rg32Uint => Sk::Uint, + Sf::Rg32Sint => Sk::Sint, + Sf::Rg32Float => Sk::Float, + Sf::Rgba16Uint => Sk::Uint, + Sf::Rgba16Sint => Sk::Sint, + Sf::Rgba16Float => Sk::Float, + Sf::Rgba32Uint => Sk::Uint, + Sf::Rgba32Sint => Sk::Sint, + Sf::Rgba32Float => Sk::Float, + } + } +} + +impl crate::TypeInner { + pub fn scalar_kind(&self) -> Option<super::ScalarKind> { + match *self { + super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => { + Some(kind) + } + super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float), + _ => None, + } + } +} diff --git a/third_party/rust/naga/src/proc/namer.rs b/third_party/rust/naga/src/proc/namer.rs new file mode 100644 index 0000000000..03b508904b --- /dev/null +++ b/third_party/rust/naga/src/proc/namer.rs @@ -0,0 +1,113 @@ +use crate::{arena::Handle, FastHashMap}; +use std::collections::hash_map::Entry; + +pub type EntryPointIndex = u16; + +#[derive(Debug, Eq, Hash, PartialEq)] +pub enum NameKey { + GlobalVariable(Handle<crate::GlobalVariable>), + Type(Handle<crate::Type>), + StructMember(Handle<crate::Type>, u32), + Function(Handle<crate::Function>), + FunctionArgument(Handle<crate::Function>, u32), + FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>), + EntryPoint(EntryPointIndex), + EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>), +} + +/// This processor assigns names to all the things in a module +/// that may need identifiers in a textual backend. +pub struct Namer { + unique: FastHashMap<String, u32>, +} + +impl Namer { + fn sanitize(string: &str) -> String { + let mut base = string + .chars() + .skip_while(|c| c.is_numeric()) + .filter(|&c| c.is_ascii_alphanumeric() || c == '_') + .collect::<String>(); + // close the name by '_' if the re is a number, so that + // we can have our own number! + match base.chars().next_back() { + Some(c) if !c.is_numeric() => {} + _ => base.push('_'), + }; + base + } + + fn call(&mut self, label_raw: &str) -> String { + let base = Self::sanitize(label_raw); + match self.unique.entry(base) { + Entry::Occupied(mut e) => { + *e.get_mut() += 1; + format!("{}{}", e.key(), e.get()) + } + Entry::Vacant(e) => { + let name = e.key().to_string(); + e.insert(0); + name + } + } + } + + fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String { + self.call(match *label { + Some(ref name) => name, + None => fallback, + }) + } + + pub fn process( + module: &crate::Module, + reserved: &[&str], + output: &mut FastHashMap<NameKey, String>, + ) { + let mut this = Namer { + unique: reserved + .iter() + .map(|string| (string.to_string(), 0)) + .collect(), + }; + + for (handle, var) in module.global_variables.iter() { + let name = this.call_or(&var.name, "global"); + output.insert(NameKey::GlobalVariable(handle), name); + } + + for (ty_handle, ty) in module.types.iter() { + let ty_name = this.call_or(&ty.name, "type"); + output.insert(NameKey::Type(ty_handle), ty_name); + + if let crate::TypeInner::Struct { ref members } = ty.inner { + for (index, member) in members.iter().enumerate() { + let name = this.call_or(&member.name, "member"); + output.insert(NameKey::StructMember(ty_handle, index as u32), name); + } + } + } + + for (fun_handle, fun) in module.functions.iter() { + let fun_name = this.call_or(&fun.name, "function"); + output.insert(NameKey::Function(fun_handle), fun_name); + for (index, arg) in fun.arguments.iter().enumerate() { + let name = this.call_or(&arg.name, "param"); + output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name); + } + for (handle, var) in fun.local_variables.iter() { + let name = this.call_or(&var.name, "local"); + output.insert(NameKey::FunctionLocal(fun_handle, handle), name); + } + } + + for (ep_index, (&(_, ref base_name), ep)) in module.entry_points.iter().enumerate() { + let ep_name = this.call(base_name); + output.insert(NameKey::EntryPoint(ep_index as _), ep_name); + for (handle, var) in ep.function.local_variables.iter() { + let name = this.call_or(&var.name, "local"); + output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name); + } + } + } +} diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs new file mode 100644 index 0000000000..b09893c179 --- /dev/null +++ b/third_party/rust/naga/src/proc/typifier.rs @@ -0,0 +1,424 @@ +use crate::arena::{Arena, Handle}; + +use thiserror::Error; + +#[derive(Debug, PartialEq)] +enum Resolution { + Handle(Handle<crate::Type>), + Value(crate::TypeInner), +} + +// Clone is only implemented for numeric variants of `TypeInner`. +impl Clone for Resolution { + fn clone(&self) -> Self { + match *self { + Resolution::Handle(handle) => Resolution::Handle(handle), + Resolution::Value(ref v) => Resolution::Value(match *v { + crate::TypeInner::Scalar { kind, width } => { + crate::TypeInner::Scalar { kind, width } + } + crate::TypeInner::Vector { size, kind, width } => { + crate::TypeInner::Vector { size, kind, width } + } + crate::TypeInner::Matrix { + rows, + columns, + width, + } => crate::TypeInner::Matrix { + rows, + columns, + width, + }, + #[allow(clippy::panic)] + _ => panic!("Unepxected clone type: {:?}", v), + }), + } + } +} + +#[derive(Debug)] +pub struct Typifier { + resolutions: Vec<Resolution>, +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ResolveError { + #[error("Invalid index into array")] + InvalidAccessIndex, + #[error("Function {name} not defined")] + FunctionNotDefined { name: String }, + #[error("Function without return type")] + FunctionReturnsVoid, + #[error("Type is not found in the given immutable arena")] + TypeNotFound, + #[error("Incompatible operand: {op} {operand}")] + IncompatibleOperand { op: String, operand: String }, + #[error("Incompatible operands: {left} {op} {right}")] + IncompatibleOperands { + op: String, + left: String, + right: String, + }, +} + +pub struct ResolveContext<'a> { + pub constants: &'a Arena<crate::Constant>, + pub global_vars: &'a Arena<crate::GlobalVariable>, + pub local_vars: &'a Arena<crate::LocalVariable>, + pub functions: &'a Arena<crate::Function>, + pub arguments: &'a [crate::FunctionArgument], +} + +impl Typifier { + pub fn new() -> Self { + Typifier { + resolutions: Vec::new(), + } + } + + pub fn clear(&mut self) { + self.resolutions.clear() + } + + pub fn get<'a>( + &'a self, + expr_handle: Handle<crate::Expression>, + types: &'a Arena<crate::Type>, + ) -> &'a crate::TypeInner { + match self.resolutions[expr_handle.index()] { + Resolution::Handle(ty_handle) => &types[ty_handle].inner, + Resolution::Value(ref inner) => inner, + } + } + + pub fn get_handle( + &self, + expr_handle: Handle<crate::Expression>, + ) -> Option<Handle<crate::Type>> { + match self.resolutions[expr_handle.index()] { + Resolution::Handle(ty_handle) => Some(ty_handle), + Resolution::Value(_) => None, + } + } + + fn resolve_impl( + &self, + expr: &crate::Expression, + types: &Arena<crate::Type>, + ctx: &ResolveContext, + ) -> Result<Resolution, ResolveError> { + Ok(match *expr { + crate::Expression::Access { base, .. } => match *self.get(base, types) { + crate::TypeInner::Array { base, .. } => Resolution::Handle(base), + crate::TypeInner::Vector { + size: _, + kind, + width, + } => Resolution::Value(crate::TypeInner::Scalar { kind, width }), + crate::TypeInner::Matrix { + rows: size, + columns: _, + width, + } => Resolution::Value(crate::TypeInner::Vector { + size, + kind: crate::ScalarKind::Float, + width, + }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "access".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::AccessIndex { base, index } => match *self.get(base, types) { + crate::TypeInner::Vector { size, kind, width } => { + if index >= size as u32 { + return Err(ResolveError::InvalidAccessIndex); + } + Resolution::Value(crate::TypeInner::Scalar { kind, width }) + } + crate::TypeInner::Matrix { + columns, + rows, + width, + } => { + if index >= columns as u32 { + return Err(ResolveError::InvalidAccessIndex); + } + Resolution::Value(crate::TypeInner::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }) + } + crate::TypeInner::Array { base, .. } => Resolution::Handle(base), + crate::TypeInner::Struct { ref members } => { + let member = members + .get(index as usize) + .ok_or(ResolveError::InvalidAccessIndex)?; + Resolution::Handle(member.ty) + } + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "access index".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::Constant(h) => Resolution::Handle(ctx.constants[h].ty), + crate::Expression::Compose { ty, .. } => Resolution::Handle(ty), + crate::Expression::FunctionArgument(index) => { + Resolution::Handle(ctx.arguments[index as usize].ty) + } + crate::Expression::GlobalVariable(h) => Resolution::Handle(ctx.global_vars[h].ty), + crate::Expression::LocalVariable(h) => Resolution::Handle(ctx.local_vars[h].ty), + crate::Expression::Load { .. } => unimplemented!(), + crate::Expression::ImageSample { image, .. } + | crate::Expression::ImageLoad { image, .. } => match *self.get(image, types) { + crate::TypeInner::Image { class, .. } => Resolution::Value(match class { + crate::ImageClass::Depth => crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width: 4, + }, + crate::ImageClass::Sampled { kind, multi: _ } => crate::TypeInner::Vector { + kind, + width: 4, + size: crate::VectorSize::Quad, + }, + crate::ImageClass::Storage(format) => crate::TypeInner::Vector { + kind: format.into(), + width: 4, + size: crate::VectorSize::Quad, + }, + }), + _ => unreachable!(), + }, + crate::Expression::Unary { expr, .. } => self.resolutions[expr.index()].clone(), + crate::Expression::Binary { op, left, right } => match op { + crate::BinaryOperator::Add + | crate::BinaryOperator::Subtract + | crate::BinaryOperator::Divide + | crate::BinaryOperator::Modulo => self.resolutions[left.index()].clone(), + crate::BinaryOperator::Multiply => { + let ty_left = self.get(left, types); + let ty_right = self.get(right, types); + if ty_left == ty_right { + self.resolutions[left.index()].clone() + } else if let crate::TypeInner::Scalar { .. } = *ty_right { + self.resolutions[left.index()].clone() + } else { + match *ty_left { + crate::TypeInner::Scalar { .. } => { + self.resolutions[right.index()].clone() + } + crate::TypeInner::Matrix { + columns, + rows: _, + width, + } => Resolution::Value(crate::TypeInner::Vector { + size: columns, + kind: crate::ScalarKind::Float, + width, + }), + _ => { + return Err(ResolveError::IncompatibleOperands { + op: "x".to_string(), + left: format!("{:?}", ty_left), + right: format!("{:?}", ty_right), + }) + } + } + } + } + crate::BinaryOperator::Equal + | crate::BinaryOperator::NotEqual + | crate::BinaryOperator::Less + | crate::BinaryOperator::LessEqual + | crate::BinaryOperator::Greater + | crate::BinaryOperator::GreaterEqual + | crate::BinaryOperator::LogicalAnd + | crate::BinaryOperator::LogicalOr => self.resolutions[left.index()].clone(), + crate::BinaryOperator::And + | crate::BinaryOperator::ExclusiveOr + | crate::BinaryOperator::InclusiveOr + | crate::BinaryOperator::ShiftLeft + | crate::BinaryOperator::ShiftRight => self.resolutions[left.index()].clone(), + }, + crate::Expression::Select { accept, .. } => self.resolutions[accept.index()].clone(), + crate::Expression::Intrinsic { .. } => unimplemented!(), + crate::Expression::Transpose(expr) => match *self.get(expr, types) { + crate::TypeInner::Matrix { + columns, + rows, + width, + } => Resolution::Value(crate::TypeInner::Matrix { + columns: rows, + rows: columns, + width, + }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "transpose".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::DotProduct(left_expr, _) => match *self.get(left_expr, types) { + crate::TypeInner::Vector { + kind, + size: _, + width, + } => Resolution::Value(crate::TypeInner::Scalar { kind, width }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "dot product".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::CrossProduct(_, _) => unimplemented!(), + crate::Expression::As { + expr, + kind, + convert: _, + } => match *self.get(expr, types) { + crate::TypeInner::Scalar { kind: _, width } => { + Resolution::Value(crate::TypeInner::Scalar { kind, width }) + } + crate::TypeInner::Vector { + kind: _, + size, + width, + } => Resolution::Value(crate::TypeInner::Vector { kind, size, width }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "as".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::Derivative { .. } => unimplemented!(), + crate::Expression::Call { + origin: crate::FunctionOrigin::External(ref name), + ref arguments, + } => match name.as_str() { + "distance" | "length" => match *self.get(arguments[0], types) { + crate::TypeInner::Vector { kind, width, .. } + | crate::TypeInner::Scalar { kind, width } => { + Resolution::Value(crate::TypeInner::Scalar { kind, width }) + } + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: name.clone(), + operand: format!("{:?}", other), + }) + } + }, + "dot" => match *self.get(arguments[0], types) { + crate::TypeInner::Vector { kind, width, .. } => { + Resolution::Value(crate::TypeInner::Scalar { kind, width }) + } + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: name.clone(), + operand: format!("{:?}", other), + }) + } + }, + //Note: `cross` is here too, we still need to figure out what to do with it + "abs" | "atan2" | "cos" | "sin" | "floor" | "inverse" | "normalize" | "min" + | "max" | "reflect" | "pow" | "clamp" | "fclamp" | "mix" | "step" + | "smoothstep" | "cross" => self.resolutions[arguments[0].index()].clone(), + _ => return Err(ResolveError::FunctionNotDefined { name: name.clone() }), + }, + crate::Expression::Call { + origin: crate::FunctionOrigin::Local(handle), + arguments: _, + } => { + let ty = ctx.functions[handle] + .return_type + .ok_or(ResolveError::FunctionReturnsVoid)?; + Resolution::Handle(ty) + } + crate::Expression::ArrayLength(_) => Resolution::Value(crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }), + }) + } + + pub fn grow( + &mut self, + expr_handle: Handle<crate::Expression>, + expressions: &Arena<crate::Expression>, + types: &mut Arena<crate::Type>, + ctx: &ResolveContext, + ) -> Result<(), ResolveError> { + if self.resolutions.len() <= expr_handle.index() { + for (eh, expr) in expressions.iter().skip(self.resolutions.len()) { + let resolution = self.resolve_impl(expr, types, ctx)?; + log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution); + + let ty_handle = match resolution { + Resolution::Handle(h) => h, + Resolution::Value(inner) => types + .fetch_if_or_append(crate::Type { name: None, inner }, |a, b| { + a.inner == b.inner + }), + }; + self.resolutions.push(Resolution::Handle(ty_handle)); + } + } + Ok(()) + } + + pub fn resolve_all( + &mut self, + expressions: &Arena<crate::Expression>, + types: &Arena<crate::Type>, + ctx: &ResolveContext, + ) -> Result<(), ResolveError> { + self.clear(); + for (_, expr) in expressions.iter() { + let resolution = self.resolve_impl(expr, types, ctx)?; + self.resolutions.push(resolution); + } + Ok(()) + } +} + +pub fn check_constant_type(inner: &crate::ConstantInner, type_inner: &crate::TypeInner) -> bool { + match (inner, type_inner) { + ( + crate::ConstantInner::Sint(_), + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + }, + ) => true, + ( + crate::ConstantInner::Uint(_), + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + }, + ) => true, + ( + crate::ConstantInner::Float(_), + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width: _, + }, + ) => true, + ( + crate::ConstantInner::Bool(_), + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: _, + }, + ) => true, + (crate::ConstantInner::Composite(_inner), _) => true, // TODO recursively check composite types + (_, _) => false, + } +} diff --git a/third_party/rust/naga/src/proc/validator.rs b/third_party/rust/naga/src/proc/validator.rs new file mode 100644 index 0000000000..d9d3eac659 --- /dev/null +++ b/third_party/rust/naga/src/proc/validator.rs @@ -0,0 +1,489 @@ +use super::typifier::{ResolveContext, ResolveError, Typifier}; +use crate::arena::{Arena, Handle}; + +const MAX_BIND_GROUPS: u32 = 8; +const MAX_LOCATIONS: u32 = 64; // using u64 mask +const MAX_BIND_INDICES: u32 = 64; // using u64 mask +const MAX_WORKGROUP_SIZE: u32 = 0x4000; + +#[derive(Debug)] +pub struct Validator { + //Note: this is a bit tricky: some of the front-ends as well as backends + // already have to use the typifier, so the work here is redundant in a way. + typifier: Typifier, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum GlobalVariableError { + #[error("Usage isn't compatible with the storage class")] + InvalidUsage, + #[error("Type isn't compatible with the storage class")] + InvalidType, + #[error("Interpolation is not valid")] + InvalidInterpolation, + #[error("Storage access {seen:?} exceed the allowed {allowed:?}")] + InvalidStorageAccess { + allowed: crate::StorageAccess, + seen: crate::StorageAccess, + }, + #[error("Binding decoration is missing or not applicable")] + InvalidBinding, + #[error("Binding is out of range")] + OutOfRangeBinding, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum LocalVariableError { + #[error("Initializer is not a constant expression")] + InitializerConst, + #[error("Initializer doesn't match the variable type")] + InitializerType, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum FunctionError { + #[error(transparent)] + Resolve(#[from] ResolveError), + #[error("There are instructions after `return`/`break`/`continue`")] + InvalidControlFlowExitTail, + #[error("Local variable {handle:?} '{name}' is invalid: {error:?}")] + LocalVariable { + handle: Handle<crate::LocalVariable>, + name: String, + error: LocalVariableError, + }, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum EntryPointError { + #[error("Early depth test is not applicable")] + UnexpectedEarlyDepthTest, + #[error("Workgroup size is not applicable")] + UnexpectedWorkgroupSize, + #[error("Workgroup size is out of range")] + OutOfRangeWorkgroupSize, + #[error("Global variable {0:?} is used incorrectly as {1:?}")] + InvalidGlobalUsage(Handle<crate::GlobalVariable>, crate::GlobalUse), + #[error("Bindings for {0:?} conflict with other global variables")] + BindingCollision(Handle<crate::GlobalVariable>), + #[error("Built-in {0:?} is not applicable to this entry point")] + InvalidBuiltIn(crate::BuiltIn), + #[error("Interpolation of an integer has to be flat")] + InvalidIntegerInterpolation, + #[error(transparent)] + Function(#[from] FunctionError), +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum ValidationError { + #[error("The type {0:?} width {1} is not supported")] + InvalidTypeWidth(crate::ScalarKind, crate::Bytes), + #[error("The type handle {0:?} can not be resolved")] + UnresolvedType(Handle<crate::Type>), + #[error("The constant {0:?} can not be used for an array size")] + InvalidArraySizeConstant(Handle<crate::Constant>), + #[error("Global variable {handle:?} '{name}' is invalid: {error:?}")] + GlobalVariable { + handle: Handle<crate::GlobalVariable>, + name: String, + error: GlobalVariableError, + }, + #[error("Function {0:?} is invalid: {1:?}")] + Function(Handle<crate::Function>, FunctionError), + #[error("Entry point {name} at {stage:?} is invalid: {error:?}")] + EntryPoint { + stage: crate::ShaderStage, + name: String, + error: EntryPointError, + }, + #[error("Module is corrupted")] + Corrupted, +} + +impl crate::GlobalVariable { + fn forbid_interpolation(&self) -> Result<(), GlobalVariableError> { + match self.interpolation { + Some(_) => Err(GlobalVariableError::InvalidInterpolation), + None => Ok(()), + } + } + + fn check_resource(&self) -> Result<(), GlobalVariableError> { + match self.binding { + Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point + Some(crate::Binding::Resource { group, binding }) => { + if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES { + return Err(GlobalVariableError::OutOfRangeBinding); + } + } + Some(crate::Binding::Location(_)) | None => { + return Err(GlobalVariableError::InvalidBinding) + } + } + self.forbid_interpolation() + } +} + +fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse { + let mut storage_usage = crate::GlobalUse::empty(); + if access.contains(crate::StorageAccess::LOAD) { + storage_usage |= crate::GlobalUse::LOAD; + } + if access.contains(crate::StorageAccess::STORE) { + storage_usage |= crate::GlobalUse::STORE; + } + storage_usage +} + +impl Validator { + /// Construct a new validator instance. + pub fn new() -> Self { + Validator { + typifier: Typifier::new(), + } + } + + fn validate_global_var( + &self, + var: &crate::GlobalVariable, + types: &Arena<crate::Type>, + ) -> Result<(), GlobalVariableError> { + log::debug!("var {:?}", var); + let allowed_storage_access = match var.class { + crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage), + crate::StorageClass::Input | crate::StorageClass::Output => { + match var.binding { + Some(crate::Binding::BuiltIn(_)) => { + // validated per entry point + var.forbid_interpolation()? + } + Some(crate::Binding::Location(loc)) => { + if loc > MAX_LOCATIONS { + return Err(GlobalVariableError::OutOfRangeBinding); + } + match types[var.ty].inner { + crate::TypeInner::Scalar { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } => {} + _ => return Err(GlobalVariableError::InvalidType), + } + } + Some(crate::Binding::Resource { .. }) => { + return Err(GlobalVariableError::InvalidBinding) + } + None => { + match types[var.ty].inner { + //TODO: check the member types + crate::TypeInner::Struct { members: _ } => { + var.forbid_interpolation()? + } + _ => return Err(GlobalVariableError::InvalidType), + } + } + } + crate::StorageAccess::empty() + } + crate::StorageClass::Storage => { + var.check_resource()?; + crate::StorageAccess::all() + } + crate::StorageClass::Uniform => { + var.check_resource()?; + crate::StorageAccess::empty() + } + crate::StorageClass::Handle => { + var.check_resource()?; + match types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => crate::StorageAccess::all(), + _ => crate::StorageAccess::empty(), + } + } + crate::StorageClass::Private | crate::StorageClass::WorkGroup => { + if var.binding.is_some() { + return Err(GlobalVariableError::InvalidBinding); + } + var.forbid_interpolation()?; + crate::StorageAccess::empty() + } + crate::StorageClass::PushConstant => { + //TODO + return Err(GlobalVariableError::InvalidStorageAccess { + allowed: crate::StorageAccess::empty(), + seen: crate::StorageAccess::empty(), + }); + } + }; + + if !allowed_storage_access.contains(var.storage_access) { + return Err(GlobalVariableError::InvalidStorageAccess { + allowed: allowed_storage_access, + seen: var.storage_access, + }); + } + + Ok(()) + } + + fn validate_local_var( + &self, + var: &crate::LocalVariable, + _fun: &crate::Function, + _types: &Arena<crate::Type>, + ) -> Result<(), LocalVariableError> { + log::debug!("var {:?}", var); + if let Some(_expr_handle) = var.init { + if false { + return Err(LocalVariableError::InitializerConst); + } + } + Ok(()) + } + + fn validate_function( + &mut self, + fun: &crate::Function, + module: &crate::Module, + ) -> Result<(), FunctionError> { + let resolve_ctx = ResolveContext { + constants: &module.constants, + global_vars: &module.global_variables, + local_vars: &fun.local_variables, + functions: &module.functions, + arguments: &fun.arguments, + }; + self.typifier + .resolve_all(&fun.expressions, &module.types, &resolve_ctx)?; + + for (var_handle, var) in fun.local_variables.iter() { + self.validate_local_var(var, fun, &module.types) + .map_err(|error| FunctionError::LocalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + error, + })?; + } + Ok(()) + } + + fn validate_entry_point( + &mut self, + ep: &crate::EntryPoint, + stage: crate::ShaderStage, + module: &crate::Module, + ) -> Result<(), EntryPointError> { + if ep.early_depth_test.is_some() && stage != crate::ShaderStage::Fragment { + return Err(EntryPointError::UnexpectedEarlyDepthTest); + } + if stage == crate::ShaderStage::Compute { + if ep + .workgroup_size + .iter() + .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE) + { + return Err(EntryPointError::OutOfRangeWorkgroupSize); + } + } else if ep.workgroup_size != [0; 3] { + return Err(EntryPointError::UnexpectedWorkgroupSize); + } + + let mut bind_group_masks = [0u64; MAX_BIND_GROUPS as usize]; + let mut location_in_mask = 0u64; + let mut location_out_mask = 0u64; + for ((var_handle, var), &usage) in module + .global_variables + .iter() + .zip(&ep.function.global_usage) + { + if usage.is_empty() { + continue; + } + + if let Some(crate::Binding::Location(_)) = var.binding { + match (stage, var.class) { + (crate::ShaderStage::Vertex, crate::StorageClass::Output) + | (crate::ShaderStage::Fragment, crate::StorageClass::Input) => { + match module.types[var.ty].inner.scalar_kind() { + Some(crate::ScalarKind::Float) => {} + Some(_) if var.interpolation != Some(crate::Interpolation::Flat) => { + return Err(EntryPointError::InvalidIntegerInterpolation); + } + _ => {} + } + } + _ => {} + } + } + + let allowed_usage = match var.class { + crate::StorageClass::Function => unreachable!(), + crate::StorageClass::Input => { + let mask = match var.binding { + Some(crate::Binding::BuiltIn(built_in)) => match (stage, built_in) { + (crate::ShaderStage::Vertex, crate::BuiltIn::BaseInstance) + | (crate::ShaderStage::Vertex, crate::BuiltIn::BaseVertex) + | (crate::ShaderStage::Vertex, crate::BuiltIn::InstanceIndex) + | (crate::ShaderStage::Vertex, crate::BuiltIn::VertexIndex) + | (crate::ShaderStage::Fragment, crate::BuiltIn::PointSize) + | (crate::ShaderStage::Fragment, crate::BuiltIn::FragCoord) + | (crate::ShaderStage::Fragment, crate::BuiltIn::FrontFacing) + | (crate::ShaderStage::Fragment, crate::BuiltIn::SampleIndex) + | (crate::ShaderStage::Compute, crate::BuiltIn::GlobalInvocationId) + | (crate::ShaderStage::Compute, crate::BuiltIn::LocalInvocationId) + | (crate::ShaderStage::Compute, crate::BuiltIn::LocalInvocationIndex) + | (crate::ShaderStage::Compute, crate::BuiltIn::WorkGroupId) => 0, + _ => return Err(EntryPointError::InvalidBuiltIn(built_in)), + }, + Some(crate::Binding::Location(loc)) => 1 << loc, + Some(crate::Binding::Resource { .. }) => unreachable!(), + None => 0, + }; + if location_in_mask & mask != 0 { + return Err(EntryPointError::BindingCollision(var_handle)); + } + location_in_mask |= mask; + crate::GlobalUse::LOAD + } + crate::StorageClass::Output => { + let mask = match var.binding { + Some(crate::Binding::BuiltIn(built_in)) => match (stage, built_in) { + (crate::ShaderStage::Vertex, crate::BuiltIn::Position) + | (crate::ShaderStage::Vertex, crate::BuiltIn::PointSize) + | (crate::ShaderStage::Vertex, crate::BuiltIn::ClipDistance) + | (crate::ShaderStage::Fragment, crate::BuiltIn::FragDepth) => 0, + _ => return Err(EntryPointError::InvalidBuiltIn(built_in)), + }, + Some(crate::Binding::Location(loc)) => 1 << loc, + Some(crate::Binding::Resource { .. }) => unreachable!(), + None => 0, + }; + if location_out_mask & mask != 0 { + return Err(EntryPointError::BindingCollision(var_handle)); + } + location_out_mask |= mask; + crate::GlobalUse::LOAD | crate::GlobalUse::STORE + } + crate::StorageClass::Uniform => crate::GlobalUse::LOAD, + crate::StorageClass::Storage => storage_usage(var.storage_access), + crate::StorageClass::Handle => match module.types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => storage_usage(var.storage_access), + _ => crate::GlobalUse::LOAD, + }, + crate::StorageClass::Private | crate::StorageClass::WorkGroup => { + crate::GlobalUse::all() + } + crate::StorageClass::PushConstant => crate::GlobalUse::LOAD, + }; + if !allowed_usage.contains(usage) { + log::warn!("\tUsage error for: {:?}", var); + log::warn!( + "\tAllowed usage: {:?}, requested: {:?}", + allowed_usage, + usage + ); + return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)); + } + + if let Some(crate::Binding::Resource { group, binding }) = var.binding { + let mask = 1 << binding; + let group_mask = &mut bind_group_masks[group as usize]; + if *group_mask & mask != 0 { + return Err(EntryPointError::BindingCollision(var_handle)); + } + *group_mask |= mask; + } + } + + self.validate_function(&ep.function, module)?; + Ok(()) + } + + /// Check the given module to be valid. + pub fn validate(&mut self, module: &crate::Module) -> Result<(), ValidationError> { + // check the types + for (handle, ty) in module.types.iter() { + use crate::TypeInner as Ti; + match ty.inner { + Ti::Scalar { kind, width } | Ti::Vector { kind, width, .. } => { + let expected = match kind { + crate::ScalarKind::Bool => 1, + _ => 4, + }; + if width != expected { + return Err(ValidationError::InvalidTypeWidth(kind, width)); + } + } + Ti::Matrix { width, .. } => { + if width != 4 { + return Err(ValidationError::InvalidTypeWidth( + crate::ScalarKind::Float, + width, + )); + } + } + Ti::Pointer { base, class: _ } => { + if base >= handle { + return Err(ValidationError::UnresolvedType(base)); + } + } + Ti::Array { base, size, .. } => { + if base >= handle { + return Err(ValidationError::UnresolvedType(base)); + } + if let crate::ArraySize::Constant(const_handle) = size { + let constant = module + .constants + .try_get(const_handle) + .ok_or(ValidationError::Corrupted)?; + match constant.inner { + crate::ConstantInner::Uint(_) => {} + _ => { + return Err(ValidationError::InvalidArraySizeConstant(const_handle)) + } + } + } + } + Ti::Struct { ref members } => { + //TODO: check that offsets are not intersecting? + for member in members { + if member.ty >= handle { + return Err(ValidationError::UnresolvedType(member.ty)); + } + } + } + Ti::Image { .. } => {} + Ti::Sampler { comparison: _ } => {} + } + } + + for (var_handle, var) in module.global_variables.iter() { + self.validate_global_var(var, &module.types) + .map_err(|error| ValidationError::GlobalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + error, + })?; + } + + for (fun_handle, fun) in module.functions.iter() { + self.validate_function(fun, module) + .map_err(|e| ValidationError::Function(fun_handle, e))?; + } + + for (&(stage, ref name), entry_point) in module.entry_points.iter() { + self.validate_entry_point(entry_point, stage, module) + .map_err(|error| ValidationError::EntryPoint { + stage, + name: name.to_string(), + error, + })?; + } + + Ok(()) + } +} |