diff options
Diffstat (limited to 'third_party/rust/naga/src/front')
24 files changed, 944 insertions, 316 deletions
diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs index f26c57965d..6ba7df593a 100644 --- a/third_party/rust/naga/src/front/glsl/context.rs +++ b/third_party/rust/naga/src/front/glsl/context.rs @@ -77,12 +77,19 @@ pub struct Context<'a> { pub body: Block, pub module: &'a mut crate::Module, pub is_const: bool, - /// Tracks the constness of `Expression`s residing in `self.expressions` - pub expression_constness: crate::proc::ExpressionConstnessTracker, + /// Tracks the expression kind of `Expression`s residing in `self.expressions` + pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker, + /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions` + pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker, } impl<'a> Context<'a> { - pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result<Self> { + pub fn new( + frontend: &Frontend, + module: &'a mut crate::Module, + is_const: bool, + global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker, + ) -> Result<Self> { let mut this = Context { expressions: Arena::new(), locals: Arena::new(), @@ -101,7 +108,8 @@ impl<'a> Context<'a> { body: Block::new(), module, is_const: false, - expression_constness: crate::proc::ExpressionConstnessTracker::new(), + local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(), + global_expression_kind_tracker, }; this.emit_start(); @@ -249,40 +257,24 @@ impl<'a> Context<'a> { pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> { let mut eval = if self.is_const { - crate::proc::ConstantEvaluator::for_glsl_module(self.module) + crate::proc::ConstantEvaluator::for_glsl_module( + self.module, + self.global_expression_kind_tracker, + ) } else { crate::proc::ConstantEvaluator::for_glsl_function( self.module, &mut self.expressions, - &mut self.expression_constness, + &mut self.local_expression_kind_tracker, &mut self.emitter, &mut self.body, ) }; - let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error { + eval.try_eval_and_append(expr, meta).map_err(|e| Error { kind: e.into(), meta, - }); - - match res { - Ok(expr) => Ok(expr), - Err(e) => { - if self.is_const { - Err(e) - } else { - let needs_pre_emit = expr.needs_pre_emit(); - if needs_pre_emit { - self.body.extend(self.emitter.finish(&self.expressions)); - } - let h = self.expressions.append(expr, meta); - if needs_pre_emit { - self.emitter.start(&self.expressions); - } - Ok(h) - } - } - } + }) } /// Add variable to current scope @@ -1479,7 +1471,7 @@ impl Index<Handle<Expression>> for Context<'_> { fn index(&self, index: Handle<Expression>) -> &Self::Output { if self.is_const { - &self.module.const_expressions[index] + &self.module.global_expressions[index] } else { &self.expressions[index] } diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs index bd16ee30bc..e0771437e6 100644 --- a/third_party/rust/naga/src/front/glsl/error.rs +++ b/third_party/rust/naga/src/front/glsl/error.rs @@ -1,4 +1,5 @@ use super::token::TokenValue; +use crate::SourceLocation; use crate::{proc::ConstantEvaluatorError, Span}; use codespan_reporting::diagnostic::{Diagnostic, Label}; use codespan_reporting::files::SimpleFile; @@ -137,14 +138,21 @@ pub struct Error { pub meta: Span, } +impl Error { + /// Returns a [`SourceLocation`] for the error message. + pub fn location(&self, source: &str) -> Option<SourceLocation> { + Some(self.meta.location(source)) + } +} + /// A collection of errors returned during shader parsing. #[derive(Clone, Debug)] #[cfg_attr(test, derive(PartialEq))] -pub struct ParseError { +pub struct ParseErrors { pub errors: Vec<Error>, } -impl ParseError { +impl ParseErrors { pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) { self.emit_to_writer_with_path(writer, source, "glsl"); } @@ -172,19 +180,19 @@ impl ParseError { } } -impl std::fmt::Display for ParseError { +impl std::fmt::Display for ParseErrors { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { self.errors.iter().try_for_each(|e| write!(f, "{e:?}")) } } -impl std::error::Error for ParseError { +impl std::error::Error for ParseErrors { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None } } -impl From<Vec<Error>> for ParseError { +impl From<Vec<Error>> for ParseErrors { fn from(errors: Vec<Error>) -> Self { Self { errors } } diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs index 01846eb814..fa1bbef56b 100644 --- a/third_party/rust/naga/src/front/glsl/functions.rs +++ b/third_party/rust/naga/src/front/glsl/functions.rs @@ -1236,6 +1236,8 @@ impl Frontend { let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); + ctx.local_expression_kind_tracker + .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; @@ -1256,6 +1258,8 @@ impl Frontend { let value = ctx .expressions .append(Expression::FunctionArgument(idx), Default::default()); + ctx.local_expression_kind_tracker + .insert(value, crate::proc::ExpressionKind::Runtime); ctx.body .push(Statement::Store { pointer, value }, Default::default()); }, @@ -1285,6 +1289,8 @@ impl Frontend { let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); + ctx.local_expression_kind_tracker + .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; @@ -1307,6 +1313,8 @@ impl Frontend { let load = ctx .expressions .append(Expression::Load { pointer }, Default::default()); + ctx.local_expression_kind_tracker + .insert(load, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), @@ -1329,6 +1337,8 @@ impl Frontend { let res = ctx .expressions .append(Expression::Compose { ty, components }, Default::default()); + ctx.local_expression_kind_tracker + .insert(res, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs index 75f3929db4..ea202b2445 100644 --- a/third_party/rust/naga/src/front/glsl/mod.rs +++ b/third_party/rust/naga/src/front/glsl/mod.rs @@ -13,7 +13,7 @@ To begin, take a look at the documentation for the [`Frontend`]. */ pub use ast::{Precision, Profile}; -pub use error::{Error, ErrorKind, ExpectedToken, ParseError}; +pub use error::{Error, ErrorKind, ExpectedToken, ParseErrors}; pub use token::TokenValue; use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type}; @@ -196,7 +196,7 @@ impl Frontend { &mut self, options: &Options, source: &str, - ) -> std::result::Result<Module, ParseError> { + ) -> std::result::Result<Module, ParseErrors> { self.reset(options.stage); let lexer = lex::Lexer::new(source, &options.defines); diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs index 851d2e1d79..28e0808063 100644 --- a/third_party/rust/naga/src/front/glsl/parser.rs +++ b/third_party/rust/naga/src/front/glsl/parser.rs @@ -164,9 +164,15 @@ impl<'source> ParsingContext<'source> { pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> { let mut module = Module::default(); + let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); // Body and expression arena for global initialization - let mut ctx = Context::new(frontend, &mut module, false)?; + let mut ctx = Context::new( + frontend, + &mut module, + false, + &mut global_expression_kind_tracker, + )?; while self.peek(frontend).is_some() { self.parse_external_declaration(frontend, &mut ctx)?; @@ -196,7 +202,11 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, ) -> Result<(u32, Span)> { - let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?; + let (const_expr, meta) = self.parse_constant_expression( + frontend, + ctx.module, + ctx.global_expression_kind_tracker, + )?; let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr); @@ -219,8 +229,9 @@ impl<'source> ParsingContext<'source> { &mut self, frontend: &mut Frontend, module: &mut Module, + global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, ) -> Result<(Handle<Expression>, Span)> { - let mut ctx = Context::new(frontend, module, true)?; + let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?; let mut stmt_ctx = ctx.stmt_ctx(); let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?; diff --git a/third_party/rust/naga/src/front/glsl/parser/declarations.rs b/third_party/rust/naga/src/front/glsl/parser/declarations.rs index f5e38fb016..2d253a378d 100644 --- a/third_party/rust/naga/src/front/glsl/parser/declarations.rs +++ b/third_party/rust/naga/src/front/glsl/parser/declarations.rs @@ -251,7 +251,7 @@ impl<'source> ParsingContext<'source> { init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok()); late_initializer = None; } else if let Some(init) = init { - if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) { + if ctx.is_inside_loop || !ctx.ctx.local_expression_kind_tracker.is_const(init) { decl_initializer = None; late_initializer = Some(init); } else { @@ -326,7 +326,12 @@ impl<'source> ParsingContext<'source> { let result = ty.map(|ty| FunctionResult { ty, binding: None }); - let mut context = Context::new(frontend, ctx.module, false)?; + let mut context = Context::new( + frontend, + ctx.module, + false, + ctx.global_expression_kind_tracker, + )?; self.parse_function_args(frontend, &mut context)?; diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs index d428d74761..d0c889e4d3 100644 --- a/third_party/rust/naga/src/front/glsl/parser/functions.rs +++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs @@ -192,10 +192,13 @@ impl<'source> ParsingContext<'source> { TokenValue::Case => { self.bump(frontend)?; - let (const_expr, meta) = - self.parse_constant_expression(frontend, ctx.module)?; + let (const_expr, meta) = self.parse_constant_expression( + frontend, + ctx.module, + ctx.global_expression_kind_tracker, + )?; - match ctx.module.const_expressions[const_expr] { + match ctx.module.global_expressions[const_expr] { Expression::Literal(Literal::I32(value)) => match uint { // This unchecked cast isn't good, but since // we only reach this code when the selector diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs index 259052cd27..135765ca58 100644 --- a/third_party/rust/naga/src/front/glsl/parser_tests.rs +++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs @@ -1,7 +1,7 @@ use super::{ ast::Profile, error::ExpectedToken, - error::{Error, ErrorKind, ParseError}, + error::{Error, ErrorKind, ParseErrors}, token::TokenValue, Frontend, Options, Span, }; @@ -21,7 +21,7 @@ fn version() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::InvalidVersion(99000), meta: Span::new(9, 14) @@ -37,7 +37,7 @@ fn version() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::InvalidVersion(449), meta: Span::new(9, 12) @@ -53,7 +53,7 @@ fn version() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::InvalidProfile("smart".into()), meta: Span::new(13, 18), @@ -69,7 +69,7 @@ fn version() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![ Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,), @@ -455,7 +455,7 @@ fn functions() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::SemanticError("Function already defined".into()), meta: Span::new(134, 152), @@ -539,7 +539,7 @@ fn constants() { let mut types = module.types.iter(); let mut constants = module.constants.iter(); - let mut const_expressions = module.const_expressions.iter(); + let mut global_expressions = module.global_expressions.iter(); let (ty_handle, ty) = types.next().unwrap(); assert_eq!( @@ -550,14 +550,13 @@ fn constants() { } ); - let (init_handle, init) = const_expressions.next().unwrap(); + let (init_handle, init) = global_expressions.next().unwrap(); assert_eq!(init, &Expression::Literal(crate::Literal::F32(1.0))); assert_eq!( constants.next().unwrap().1, &Constant { name: Some("a".to_owned()), - r#override: crate::Override::None, ty: ty_handle, init: init_handle } @@ -567,7 +566,6 @@ fn constants() { constants.next().unwrap().1, &Constant { name: Some("b".to_owned()), - r#override: crate::Override::None, ty: ty_handle, init: init_handle } @@ -636,7 +634,7 @@ fn implicit_conversions() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::SemanticError("Unknown function \'test\'".into()), meta: Span::new(156, 165), @@ -660,7 +658,7 @@ fn implicit_conversions() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()), meta: Span::new(158, 165), diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs index e87d76fffc..f6836169c0 100644 --- a/third_party/rust/naga/src/front/glsl/types.rs +++ b/third_party/rust/naga/src/front/glsl/types.rs @@ -233,7 +233,7 @@ impl Context<'_> { }; let expressions = if self.is_const { - &self.module.const_expressions + &self.module.global_expressions } else { &self.expressions }; @@ -330,23 +330,25 @@ impl Context<'_> { expr: Handle<Expression>, ) -> Result<Handle<Expression>> { let meta = self.expressions.get_span(expr); - Ok(match self.expressions[expr] { + let h = match self.expressions[expr] { ref expr @ (Expression::Literal(_) | Expression::Constant(_) - | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta), + | Expression::ZeroValue(_)) => { + self.module.global_expressions.append(expr.clone(), meta) + } Expression::Compose { ty, ref components } => { let mut components = components.clone(); for component in &mut components { *component = self.lift_up_const_expression(*component)?; } self.module - .const_expressions + .global_expressions .append(Expression::Compose { ty, components }, meta) } Expression::Splat { size, value } => { let value = self.lift_up_const_expression(value)?; self.module - .const_expressions + .global_expressions .append(Expression::Splat { size, value }, meta) } _ => { @@ -355,6 +357,9 @@ impl Context<'_> { meta, }) } - }) + }; + self.global_expression_kind_tracker + .insert(h, crate::proc::ExpressionKind::Const); + Ok(h) } } diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs index 9d2e7a0e7b..0725fbd94f 100644 --- a/third_party/rust/naga/src/front/glsl/variables.rs +++ b/third_party/rust/naga/src/front/glsl/variables.rs @@ -472,7 +472,6 @@ impl Frontend { let constant = Constant { name: name.clone(), - r#override: crate::Override::None, ty, init, }; diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs index f0a714fbeb..a6bf0e0451 100644 --- a/third_party/rust/naga/src/front/spv/convert.rs +++ b/third_party/rust/naga/src/front/spv/convert.rs @@ -153,6 +153,11 @@ pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::B Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId, Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize, Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups, + // subgroup + Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups, + Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId, + Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize, + Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnsupportedBuiltIn(word)), }) } diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs index af025636c0..44beadce98 100644 --- a/third_party/rust/naga/src/front/spv/error.rs +++ b/third_party/rust/naga/src/front/spv/error.rs @@ -5,7 +5,7 @@ use codespan_reporting::files::SimpleFile; use codespan_reporting::term; use termcolor::{NoColor, WriteColor}; -#[derive(Debug, thiserror::Error)] +#[derive(Clone, Debug, thiserror::Error)] pub enum Error { #[error("invalid header")] InvalidHeader, @@ -58,6 +58,8 @@ pub enum Error { UnknownBinaryOperator(spirv::Op), #[error("unknown relational function {0:?}")] UnknownRelationalFunction(spirv::Op), + #[error("unsupported group operation %{0}")] + UnsupportedGroupOperation(spirv::Word), #[error("invalid parameter {0:?}")] InvalidParameter(spirv::Op), #[error("invalid operand count {1} for {0:?}")] @@ -118,8 +120,8 @@ pub enum Error { ControlFlowGraphCycle(crate::front::spv::BlockId), #[error("recursive function call %{0}")] FunctionCallCycle(spirv::Word), - #[error("invalid array size {0:?}")] - InvalidArraySize(Handle<crate::Constant>), + #[error("invalid array size %{0}")] + InvalidArraySize(spirv::Word), #[error("invalid barrier scope %{0}")] InvalidBarrierScope(spirv::Word), #[error("invalid barrier memory semantics %{0}")] @@ -130,6 +132,8 @@ pub enum Error { come from a binding)" )] NonBindingArrayOfImageOrSamplers, + #[error("naga only supports specialization constant IDs up to 65535 but was given {0}")] + SpecIdTooHigh(u32), } impl Error { diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs index e81ecf5c9b..113ca56313 100644 --- a/third_party/rust/naga/src/front/spv/function.rs +++ b/third_party/rust/naga/src/front/spv/function.rs @@ -59,8 +59,11 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> { }) }, local_variables: Arena::new(), - expressions: self - .make_expression_storage(&module.global_variables, &module.constants), + expressions: self.make_expression_storage( + &module.global_variables, + &module.constants, + &module.overrides, + ), named_expressions: crate::NamedExpressions::default(), body: crate::Block::new(), } @@ -128,7 +131,8 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> { expressions: &mut fun.expressions, local_arena: &mut fun.local_variables, const_arena: &mut module.constants, - const_expressions: &mut module.const_expressions, + overrides: &mut module.overrides, + global_expressions: &mut module.global_expressions, type_arena: &module.types, global_arena: &module.global_variables, arguments: &fun.arguments, @@ -581,7 +585,8 @@ impl<'function> BlockContext<'function> { crate::proc::GlobalCtx { types: self.type_arena, constants: self.const_arena, - const_expressions: self.const_expressions, + overrides: self.overrides, + global_expressions: self.global_expressions, } } diff --git a/third_party/rust/naga/src/front/spv/image.rs b/third_party/rust/naga/src/front/spv/image.rs index 0f25dd626b..284c4cf7fd 100644 --- a/third_party/rust/naga/src/front/spv/image.rs +++ b/third_party/rust/naga/src/front/spv/image.rs @@ -507,11 +507,14 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> { } spirv::ImageOperands::CONST_OFFSET => { let offset_constant = self.next()?; - let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle; - let offset_handle = ctx.const_expressions.append( - crate::Expression::Constant(offset_handle), - Default::default(), - ); + let offset_expr = self + .lookup_constant + .lookup(offset_constant)? + .inner + .to_expr(); + let offset_handle = ctx + .global_expressions + .append(offset_expr, Default::default()); offset = Some(offset_handle); words_left -= 1; } diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs index b793448597..7ac5a18cd6 100644 --- a/third_party/rust/naga/src/front/spv/mod.rs +++ b/third_party/rust/naga/src/front/spv/mod.rs @@ -196,7 +196,7 @@ struct Decoration { location: Option<spirv::Word>, desc_set: Option<spirv::Word>, desc_index: Option<spirv::Word>, - specialization: Option<spirv::Word>, + specialization_constant_id: Option<spirv::Word>, storage_buffer: bool, offset: Option<spirv::Word>, array_stride: Option<NonZeroU32>, @@ -216,11 +216,6 @@ impl Decoration { } } - fn specialization(&self) -> crate::Override { - self.specialization - .map_or(crate::Override::None, crate::Override::ByNameOrId) - } - const fn resource_binding(&self) -> Option<crate::ResourceBinding> { match *self { Decoration { @@ -284,8 +279,23 @@ struct LookupType { } #[derive(Debug)] +enum Constant { + Constant(Handle<crate::Constant>), + Override(Handle<crate::Override>), +} + +impl Constant { + const fn to_expr(&self) -> crate::Expression { + match *self { + Self::Constant(c) => crate::Expression::Constant(c), + Self::Override(o) => crate::Expression::Override(o), + } + } +} + +#[derive(Debug)] struct LookupConstant { - handle: Handle<crate::Constant>, + inner: Constant, type_id: spirv::Word, } @@ -537,7 +547,8 @@ struct BlockContext<'function> { local_arena: &'function mut Arena<crate::LocalVariable>, /// Constants arena of the module being processed const_arena: &'function mut Arena<crate::Constant>, - const_expressions: &'function mut Arena<crate::Expression>, + overrides: &'function mut Arena<crate::Override>, + global_expressions: &'function mut Arena<crate::Expression>, /// Type arena of the module being processed type_arena: &'function UniqueArena<crate::Type>, /// Global arena of the module being processed @@ -757,7 +768,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { dec.matrix_major = Some(Majority::Row); } spirv::Decoration::SpecId => { - dec.specialization = Some(self.next()?); + dec.specialization_constant_id = Some(self.next()?); } other => { log::warn!("Unknown decoration {:?}", other); @@ -1393,10 +1404,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { inst.expect(5)?; let init_id = self.next()?; let lconst = self.lookup_constant.lookup(init_id)?; - Some( - ctx.expressions - .append(crate::Expression::Constant(lconst.handle), span), - ) + Some(ctx.expressions.append(lconst.inner.to_expr(), span)) } else { None }; @@ -3650,9 +3658,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; let semantics_const = self.lookup_constant.lookup(semantics_id)?; - let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + let exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; - let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) + let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; if exec_scope == spirv::Scope::Workgroup as u32 { @@ -3692,6 +3700,254 @@ impl<I: Iterator<Item = u32>> Frontend<I> { }, ); } + Op::GroupNonUniformBallot => { + inst.expect(5)?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let predicate_id = self.next()?; + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let predicate = if self + .lookup_constant + .lookup(predicate_id) + .ok() + .filter(|predicate_const| match predicate_const.inner { + Constant::Constant(constant) => matches!( + ctx.gctx().global_expressions[ctx.gctx().constants[constant].init], + crate::Expression::Literal(crate::Literal::Bool(true)), + ), + Constant::Override(_) => false, + }) + .is_some() + { + None + } else { + let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; + let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); + Some(predicate_handle) + }; + + let result_handle = ctx + .expressions + .append(crate::Expression::SubgroupBallotResult, span); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupBallot { + result: result_handle, + predicate, + }, + span, + ); + emitter.start(ctx.expressions); + } + spirv::Op::GroupNonUniformAll + | spirv::Op::GroupNonUniformAny + | spirv::Op::GroupNonUniformIAdd + | spirv::Op::GroupNonUniformFAdd + | spirv::Op::GroupNonUniformIMul + | spirv::Op::GroupNonUniformFMul + | spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax + | spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin + | spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalAnd + | spirv::Op::GroupNonUniformLogicalOr + | spirv::Op::GroupNonUniformLogicalXor => { + block.extend(emitter.finish(ctx.expressions)); + inst.expect( + if matches!( + inst.op, + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny + ) { + 5 + } else { + 6 + }, + )?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let collective_op_id = match inst.op { + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => { + crate::CollectiveOperation::Reduce + } + _ => { + let group_op_id = self.next()?; + match spirv::GroupOperation::from_u32(group_op_id) { + Some(spirv::GroupOperation::Reduce) => { + crate::CollectiveOperation::Reduce + } + Some(spirv::GroupOperation::InclusiveScan) => { + crate::CollectiveOperation::InclusiveScan + } + Some(spirv::GroupOperation::ExclusiveScan) => { + crate::CollectiveOperation::ExclusiveScan + } + _ => return Err(Error::UnsupportedGroupOperation(group_op_id)), + } + } + }; + let argument_id = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let op_id = match inst.op { + spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All, + spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any, + spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => { + crate::SubgroupOperation::Add + } + spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => { + crate::SubgroupOperation::Mul + } + spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, + spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, + spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And, + spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or, + spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor, + _ => unreachable!(), + }; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupCollectiveOperation { + result: result_handle, + op: op_id, + collective_op: collective_op_id, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } + Op::GroupNonUniformBroadcastFirst + | Op::GroupNonUniformBroadcast + | Op::GroupNonUniformShuffle + | Op::GroupNonUniformShuffleDown + | Op::GroupNonUniformShuffleUp + | Op::GroupNonUniformShuffleXor => { + inst.expect( + if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + 5 + } else { + 6 + }, + )?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + crate::GatherMode::BroadcastFirst + } else { + let index_id = self.next()?; + let index_lookup = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lookup); + match inst.op { + spirv::Op::GroupNonUniformBroadcast => { + crate::GatherMode::Broadcast(index_handle) + } + spirv::Op::GroupNonUniformShuffle => { + crate::GatherMode::Shuffle(index_handle) + } + spirv::Op::GroupNonUniformShuffleDown => { + crate::GatherMode::ShuffleDown(index_handle) + } + spirv::Op::GroupNonUniformShuffleUp => { + crate::GatherMode::ShuffleUp(index_handle) + } + spirv::Op::GroupNonUniformShuffleXor => { + crate::GatherMode::ShuffleXor(index_handle) + } + _ => unreachable!(), + } + }; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupGather { + result: result_handle, + mode, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -3713,6 +3969,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { &mut self, globals: &Arena<crate::GlobalVariable>, constants: &Arena<crate::Constant>, + overrides: &Arena<crate::Override>, ) -> Arena<crate::Expression> { let mut expressions = Arena::new(); #[allow(clippy::panic)] @@ -3737,8 +3994,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> { } // register constants for (&id, con) in self.lookup_constant.iter() { - let span = constants.get_span(con.handle); - let handle = expressions.append(crate::Expression::Constant(con.handle), span); + let (expr, span) = match con.inner { + Constant::Constant(c) => (crate::Expression::Constant(c), constants.get_span(c)), + Constant::Override(o) => (crate::Expression::Override(o), overrides.get_span(o)), + }; + let handle = expressions.append(expr, span); self.lookup_expression.insert( id, LookupExpression { @@ -3812,7 +4072,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> { | S::Store { .. } | S::ImageStore { .. } | S::Atomic { .. } - | S::RayQuery { .. } => {} + | S::RayQuery { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } => {} S::Call { function: ref mut callee, ref arguments, @@ -3944,10 +4207,16 @@ impl<I: Iterator<Item = u32>> Frontend<I> { Op::TypeSampledImage => self.parse_type_sampled_image(inst), Op::TypeSampler => self.parse_type_sampler(inst, &mut module), Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), - Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), + Op::ConstantComposite | Op::SpecConstantComposite => { + self.parse_composite_constant(inst, &mut module) + } Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module), - Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module), - Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module), + Op::ConstantTrue | Op::SpecConstantTrue => { + self.parse_bool_constant(inst, true, &mut module) + } + Op::ConstantFalse | Op::SpecConstantFalse => { + self.parse_bool_constant(inst, false, &mut module) + } Op::Variable => self.parse_global_variable(inst, &mut module), Op::Function => { self.switch(ModuleState::Function, inst.op)?; @@ -4504,9 +4773,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let length_id = self.next()?; let length_const = self.lookup_constant.lookup(length_id)?; - let size = resolve_constant(module.to_ctx(), length_const.handle) + let size = resolve_constant(module.to_ctx(), &length_const.inner) .and_then(NonZeroU32::new) - .ok_or(Error::InvalidArraySize(length_const.handle))?; + .ok_or(Error::InvalidArraySize(length_id))?; let decor = self.future_decor.remove(&id).unwrap_or_default(); let base = self.lookup_type.lookup(type_id)?.handle; @@ -4919,29 +5188,13 @@ impl<I: Iterator<Item = u32>> Frontend<I> { _ => return Err(Error::UnsupportedType(type_lookup.handle)), }; - let decor = self.future_decor.remove(&id).unwrap_or_default(); - let span = self.span_from_with_op(start); let init = module - .const_expressions + .global_expressions .append(crate::Expression::Literal(literal), span); - self.lookup_constant.insert( - id, - LookupConstant { - handle: module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ), - type_id, - }, - ); - Ok(()) + + self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_composite_constant( @@ -4965,34 +5218,18 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let span = self.span_from_with_op(start); let constant = self.lookup_constant.lookup(component_id)?; let expr = module - .const_expressions - .append(crate::Expression::Constant(constant.handle), span); + .global_expressions + .append(constant.inner.to_expr(), span); components.push(expr); } - let decor = self.future_decor.remove(&id).unwrap_or_default(); - let span = self.span_from_with_op(start); let init = module - .const_expressions + .global_expressions .append(crate::Expression::Compose { ty, components }, span); - self.lookup_constant.insert( - id, - LookupConstant { - handle: module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ), - type_id, - }, - ); - Ok(()) + + self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_null_constant( @@ -5010,23 +5247,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; - let decor = self.future_decor.remove(&id).unwrap_or_default(); - let init = module - .const_expressions + .global_expressions .append(crate::Expression::ZeroValue(ty), span); - let handle = module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ); - self.lookup_constant - .insert(id, LookupConstant { handle, type_id }); - Ok(()) + + self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_bool_constant( @@ -5045,27 +5270,44 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; - let decor = self.future_decor.remove(&id).unwrap_or_default(); - - let init = module.const_expressions.append( + let init = module.global_expressions.append( crate::Expression::Literal(crate::Literal::Bool(value)), span, ); - self.lookup_constant.insert( - id, - LookupConstant { - handle: module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ), - type_id, - }, - ); + + self.insert_parsed_constant(module, id, type_id, ty, init, span) + } + + fn insert_parsed_constant( + &mut self, + module: &mut crate::Module, + id: u32, + type_id: u32, + ty: Handle<crate::Type>, + init: Handle<crate::Expression>, + span: crate::Span, + ) -> Result<(), Error> { + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let inner = if let Some(id) = decor.specialization_constant_id { + let o = crate::Override { + name: decor.name, + id: Some(id.try_into().map_err(|_| Error::SpecIdTooHigh(id))?), + ty, + init: Some(init), + }; + Constant::Override(module.overrides.append(o, span)) + } else { + let c = crate::Constant { + name: decor.name, + ty, + init, + }; + Constant::Constant(module.constants.append(c, span)) + }; + + self.lookup_constant + .insert(id, LookupConstant { inner, type_id }); Ok(()) } @@ -5087,8 +5329,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let span = self.span_from_with_op(start); let lconst = self.lookup_constant.lookup(init_id)?; let expr = module - .const_expressions - .append(crate::Expression::Constant(lconst.handle), span); + .global_expressions + .append(lconst.inner.to_expr(), span); Some(expr) } else { None @@ -5209,7 +5451,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { match null::generate_default_built_in( Some(built_in), ty, - &mut module.const_expressions, + &mut module.global_expressions, span, ) { Ok(handle) => Some(handle), @@ -5231,14 +5473,14 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let handle = null::generate_default_built_in( built_in, member.ty, - &mut module.const_expressions, + &mut module.global_expressions, span, )?; components.push(handle); } Some( module - .const_expressions + .global_expressions .append(crate::Expression::Compose { ty, components }, span), ) } @@ -5303,11 +5545,12 @@ fn make_index_literal( Ok(expr) } -fn resolve_constant( - gctx: crate::proc::GlobalCtx, - constant: Handle<crate::Constant>, -) -> Option<u32> { - match gctx.const_expressions[gctx.constants[constant].init] { +fn resolve_constant(gctx: crate::proc::GlobalCtx, constant: &Constant) -> Option<u32> { + let constant = match *constant { + Constant::Constant(constant) => constant, + Constant::Override(_) => return None, + }; + match gctx.global_expressions[gctx.constants[constant].init] { crate::Expression::Literal(crate::Literal::U32(id)) => Some(id), crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32), _ => None, diff --git a/third_party/rust/naga/src/front/spv/null.rs b/third_party/rust/naga/src/front/spv/null.rs index 42cccca80a..c7d3776841 100644 --- a/third_party/rust/naga/src/front/spv/null.rs +++ b/third_party/rust/naga/src/front/spv/null.rs @@ -5,14 +5,14 @@ use crate::arena::{Arena, Handle}; pub fn generate_default_built_in( built_in: Option<crate::BuiltIn>, ty: Handle<crate::Type>, - const_expressions: &mut Arena<crate::Expression>, + global_expressions: &mut Arena<crate::Expression>, span: crate::Span, ) -> Result<Handle<crate::Expression>, Error> { let expr = match built_in { Some(crate::BuiltIn::Position { .. }) => { - let zero = const_expressions + let zero = global_expressions .append(crate::Expression::Literal(crate::Literal::F32(0.0)), span); - let one = const_expressions + let one = global_expressions .append(crate::Expression::Literal(crate::Literal::F32(1.0)), span); crate::Expression::Compose { ty, @@ -27,5 +27,5 @@ pub fn generate_default_built_in( // Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path _ => crate::Expression::ZeroValue(ty), }; - Ok(const_expressions.append(expr, span)) + Ok(global_expressions.append(expr, span)) } diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs index 54aa8296b1..dc1339521c 100644 --- a/third_party/rust/naga/src/front/wgsl/error.rs +++ b/third_party/rust/naga/src/front/wgsl/error.rs @@ -13,6 +13,7 @@ use thiserror::Error; #[derive(Clone, Debug)] pub struct ParseError { message: String, + // The first span should be the primary span, and the other ones should be complementary. labels: Vec<(Span, Cow<'static, str>)>, notes: Vec<String>, } @@ -190,7 +191,7 @@ pub enum Error<'a> { expected: String, got: String, }, - MissingType(Span), + DeclMissingTypeAndInit(Span), MissingAttribute(&'static str, Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), @@ -269,6 +270,11 @@ pub enum Error<'a> { scalar: String, inner: ConstantEvaluatorError, }, + ExceededLimitForNestedBraces { + span: Span, + limit: u8, + }, + PipelineConstantIDValue(Span), } impl<'a> Error<'a> { @@ -518,11 +524,11 @@ impl<'a> Error<'a> { notes: vec![], } } - Error::MissingType(name_span) => ParseError { - message: format!("variable `{}` needs a type", &source[name_span]), + Error::DeclMissingTypeAndInit(name_span) => ParseError { + message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]), labels: vec![( name_span, - format!("definition of `{}`", &source[name_span]).into(), + "needs a type specifier or initializer".into(), )], notes: vec![], }, @@ -770,6 +776,21 @@ impl<'a> Error<'a> { format!("the expression should have been converted to have {} scalar type", scalar), ] }, + Error::ExceededLimitForNestedBraces { span, limit } => ParseError { + message: "brace nesting limit reached".into(), + labels: vec![(span, "limit reached at this brace".into())], + notes: vec![ + format!("nesting limit is currently set to {limit}"), + ], + }, + Error::PipelineConstantIDValue(span) => ParseError { + message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(), + labels: vec![( + span, + "must be between 0 and 65535 inclusive".into(), + )], + notes: vec![], + }, } } } diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs index a5524fe8f1..593405508f 100644 --- a/third_party/rust/naga/src/front/wgsl/index.rs +++ b/third_party/rust/naga/src/front/wgsl/index.rs @@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> { ast::GlobalDeclKind::Fn(ref f) => f.name, ast::GlobalDeclKind::Var(ref v) => v.name, ast::GlobalDeclKind::Const(ref c) => c.name, + ast::GlobalDeclKind::Override(ref o) => o.name, ast::GlobalDeclKind::Struct(ref s) => s.name, ast::GlobalDeclKind::Type(ref t) => t.name, } diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs index 2ca6c182b7..e7cce17723 100644 --- a/third_party/rust/naga/src/front/wgsl/lower/mod.rs +++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs @@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> { module: &'out mut crate::Module, const_typifier: &'temp mut Typifier, + + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, } impl<'source> GlobalContext<'source, '_, '_> { @@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> { module: self.module, const_typifier: self.const_typifier, expr_type: ExpressionContextType::Constant, + global_expression_kind_tracker: self.global_expression_kind_tracker, + } + } + + fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + expr_type: ExpressionContextType::Override, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -164,7 +179,8 @@ pub struct StatementContext<'source, 'temp, 'out> { /// with the form of the expressions; it is also tracking whether WGSL says /// we should consider them to be const. See the use of `force_non_const` in /// the code for lowering `let` bindings. - expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, + local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, } impl<'a, 'temp> StatementContext<'a, 'temp, '_> { @@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { types: self.types, ast_expressions: self.ast_expressions, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, module: self.module, expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { local_table: self.local_table, @@ -188,7 +205,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { block, emitter, typifier: self.typifier, - expression_constness: self.expression_constness, + local_expression_kind_tracker: self.local_expression_kind_tracker, }), } } @@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { types: self.types, module: self.module, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -232,8 +250,8 @@ pub struct RuntimeExpressionContext<'temp, 'out> { /// Which `Expression`s in `self.naga_expressions` are const expressions, in /// the WGSL sense. /// - /// See [`StatementContext::expression_constness`] for details. - expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, + /// See [`StatementContext::local_expression_kind_tracker`] for details. + local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, } /// The type of Naga IR expression we are lowering an [`ast::Expression`] to. @@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> { /// available in the [`ExpressionContext`], so this variant /// carries no further information. Constant, + + /// We are lowering to an override expression, to be included in the module's + /// constant expression arena. + /// + /// Everything override expressions are allowed to refer to is + /// available in the [`ExpressionContext`], so this variant + /// carries no further information. + Override, } /// State for lowering an [`ast::Expression`] to Naga IR. @@ -307,10 +333,11 @@ pub struct ExpressionContext<'source, 'temp, 'out> { /// [`Module`]: crate::Module module: &'out mut crate::Module, - /// Type judgments for [`module::const_expressions`]. + /// Type judgments for [`module::global_expressions`]. /// - /// [`module::const_expressions`]: crate::Module::const_expressions + /// [`module::global_expressions`]: crate::Module::global_expressions const_typifier: &'temp mut Typifier, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, /// Whether we are lowering a constant expression or a general /// runtime expression, and the data needed in each case. @@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { const_typifier: self.const_typifier, module: self.module, expr_type: ExpressionContextType::Constant, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { types: self.types, module: self.module, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -344,11 +373,20 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function( self.module, &mut rctx.function.expressions, - rctx.expression_constness, + rctx.local_expression_kind_tracker, rctx.emitter, rctx.block, ), - ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module), + ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module( + self.module, + self.global_expression_kind_tracker, + false, + ), + ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module( + self.module, + self.global_expression_kind_tracker, + true, + ), } } @@ -358,24 +396,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { span: Span, ) -> Result<Handle<crate::Expression>, Error<'source>> { let mut eval = self.as_const_evaluator(); - match eval.try_eval_and_append(&expr, span) { - Ok(expr) => Ok(expr), - - // `expr` is not a constant expression. This is fine as - // long as we're not building `Module::const_expressions`. - Err(err) => match self.expr_type { - ExpressionContextType::Runtime(ref mut rctx) => { - Ok(rctx.function.expressions.append(expr, span)) - } - ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)), - }, - } + eval.try_eval_and_append(expr, span) + .map_err(|e| Error::ConstantEvaluatorError(e, span)) } fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => { - if !ctx.expression_constness.is_const(handle) { + if !ctx.local_expression_kind_tracker.is_const(handle) { return None; } @@ -385,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .ok() } ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(), + ExpressionContextType::Override => None, } } fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), - ExpressionContextType::Constant => self.module.const_expressions.get_span(handle), + ExpressionContextType::Constant | ExpressionContextType::Override => { + self.module.global_expressions.get_span(handle) + } } } fn typifier(&self) -> &Typifier { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant => self.const_typifier, + ExpressionContextType::Constant | ExpressionContextType::Override => { + self.const_typifier + } } } @@ -408,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> { match self.expr_type { ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx), - ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), + ExpressionContextType::Constant | ExpressionContextType::Override => { + Err(Error::UnexpectedOperationInConstContext(span)) + } } } @@ -420,7 +455,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) -> Result<crate::SwizzleComponent, Error<'source>> { match self.expr_type { ExpressionContextType::Runtime(ref rctx) => { - if !rctx.expression_constness.is_const(expr) { + if !rctx.local_expression_kind_tracker.is_const(expr) { return Err(Error::ExpectedConstExprConcreteIntegerScalar( component_span, )); @@ -445,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { } // This means a `gather` operation appeared in a constant expression. // This error refers to the `gather` itself, not its "component" argument. - ExpressionContextType::Constant => { + ExpressionContextType::Constant | ExpressionContextType::Override => { Err(Error::UnexpectedOperationInConstContext(gather_span)) } } @@ -471,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { // to also borrow self.module.types mutably below. let typifier = match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant => &*self.const_typifier, + ExpressionContextType::Constant | ExpressionContextType::Override => { + &*self.const_typifier + } }; Ok(typifier.register_type(handle, &mut self.module.types)) } @@ -514,10 +551,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { typifier = &mut *ctx.typifier; expressions = &ctx.function.expressions; } - ExpressionContextType::Constant => { + ExpressionContextType::Constant | ExpressionContextType::Override => { resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); typifier = self.const_typifier; - expressions = &self.module.const_expressions; + expressions = &self.module.global_expressions; } }; typifier @@ -610,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); } - ExpressionContextType::Constant => {} + ExpressionContextType::Constant | ExpressionContextType::Override => {} } let result = self.append_expression(expression, span); match self.expr_type { ExpressionContextType::Runtime(ref mut rctx) => { rctx.emitter.start(&rctx.function.expressions); } - ExpressionContextType::Constant => {} + ExpressionContextType::Constant | ExpressionContextType::Override => {} } result } @@ -786,6 +823,7 @@ enum LoweredGlobalDecl { Function(Handle<crate::Function>), Var(Handle<crate::GlobalVariable>), Const(Handle<crate::Constant>), + Override(Handle<crate::Override>), Type(Handle<crate::Type>), EntryPoint, } @@ -836,6 +874,29 @@ impl Texture { } } +enum SubgroupGather { + BroadcastFirst, + Broadcast, + Shuffle, + ShuffleDown, + ShuffleUp, + ShuffleXor, +} + +impl SubgroupGather { + pub fn map(word: &str) -> Option<Self> { + Some(match word { + "subgroupBroadcastFirst" => Self::BroadcastFirst, + "subgroupBroadcast" => Self::Broadcast, + "subgroupShuffle" => Self::Shuffle, + "subgroupShuffleDown" => Self::ShuffleDown, + "subgroupShuffleUp" => Self::ShuffleUp, + "subgroupShuffleXor" => Self::ShuffleXor, + _ => return None, + }) + } +} + pub struct Lowerer<'source, 'temp> { index: &'temp Index<'source>, layouter: Layouter, @@ -861,6 +922,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { types: &tu.types, module: &mut module, const_typifier: &mut Typifier::new(), + global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker::new(), }; for decl_handle in self.index.visit_ordered() { @@ -877,7 +939,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let init; if let Some(init_ast) = v.init { - let mut ectx = ctx.as_const(); + let mut ectx = ctx.as_override(); let lowered = self.expression_for_abstract(init_ast, &mut ectx)?; let ty_res = crate::proc::TypeResolution::Handle(ty); let converted = ectx @@ -956,7 +1018,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let handle = ctx.module.constants.append( crate::Constant { name: Some(c.name.name.to_string()), - r#override: crate::Override::None, ty, init, }, @@ -966,6 +1027,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ctx.globals .insert(c.name.name, LoweredGlobalDecl::Const(handle)); } + ast::GlobalDeclKind::Override(ref o) => { + let init = o + .init + .map(|init| self.expression(init, &mut ctx.as_override())) + .transpose()?; + let inferred_type = init + .map(|init| ctx.as_const().register_type(init)) + .transpose()?; + + let explicit_ty = + o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) + .transpose()?; + + let id = + o.id.map(|id| self.const_u32(id, &mut ctx.as_const())) + .transpose()?; + + let id = if let Some((id, id_span)) = id { + Some( + u16::try_from(id) + .map_err(|_| Error::PipelineConstantIDValue(id_span))?, + ) + } else { + None + }; + + let ty = match (explicit_ty, inferred_type) { + (Some(explicit_ty), Some(inferred_type)) => { + if explicit_ty == inferred_type { + explicit_ty + } else { + let gctx = ctx.module.to_ctx(); + return Err(Error::InitializationTypeMismatch { + name: o.name.span, + expected: explicit_ty.to_wgsl(&gctx), + got: inferred_type.to_wgsl(&gctx), + }); + } + } + (Some(explicit_ty), None) => explicit_ty, + (None, Some(inferred_type)) => inferred_type, + (None, None) => { + return Err(Error::DeclMissingTypeAndInit(o.name.span)); + } + }; + + let handle = ctx.module.overrides.append( + crate::Override { + name: Some(o.name.name.to_string()), + id, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(o.name.name, LoweredGlobalDecl::Override(handle)); + } ast::GlobalDeclKind::Struct(ref s) => { let handle = self.r#struct(s, span, &mut ctx)?; ctx.globals @@ -1000,6 +1120,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut local_table = FastHashMap::default(); let mut expressions = Arena::new(); let mut named_expressions = FastIndexMap::default(); + let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); let arguments = f .arguments @@ -1011,6 +1132,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); local_table.insert(arg.handle, Typed::Plain(expr)); named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span)); + local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime); Ok(crate::FunctionArgument { name: Some(arg.name.name.to_string()), @@ -1053,7 +1175,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { named_expressions: &mut named_expressions, types: ctx.types, module: ctx.module, - expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), + local_expression_kind_tracker: &mut local_expression_kind_tracker, + global_expression_kind_tracker: ctx.global_expression_kind_tracker, }; let mut body = self.block(&f.body, false, &mut stmt_ctx)?; ensure_block_returns(&mut body); @@ -1132,7 +1255,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // affects when errors must be reported, so we can't even // treat suitable `let` bindings as constant as an // optimization. - ctx.expression_constness.force_non_const(value); + ctx.local_expression_kind_tracker.force_non_const(value); let explicit_ty = l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global())) @@ -1203,7 +1326,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty = explicit_ty; initializer = None; } - (None, None) => return Err(Error::MissingType(v.name.span)), + (None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)), } let (const_initializer, initializer) = { @@ -1216,7 +1339,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // - the initialization is not a constant // expression, so its value depends on the // state at the point of initialization. - if is_inside_loop || !ctx.expression_constness.is_const(init) { + if is_inside_loop + || !ctx.local_expression_kind_tracker.is_const_or_override(init) + { (None, Some(init)) } else { (Some(init), None) @@ -1469,6 +1594,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .function .expressions .append(crate::Expression::Binary { op, left, right }, stmt.span); + rctx.local_expression_kind_tracker + .insert(left, crate::proc::ExpressionKind::Runtime); + rctx.local_expression_kind_tracker + .insert(value, crate::proc::ExpressionKind::Runtime); block.extend(emitter.finish(&ctx.function.expressions)); crate::Statement::Store { @@ -1562,7 +1691,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { LoweredGlobalDecl::Const(handle) => { Typed::Plain(crate::Expression::Constant(handle)) } - _ => { + LoweredGlobalDecl::Override(handle) => { + Typed::Plain(crate::Expression::Override(handle)) + } + LoweredGlobalDecl::Function(_) + | LoweredGlobalDecl::Type(_) + | LoweredGlobalDecl::EntryPoint => { return Err(Error::Unexpected(span, ExpectedToken::Variable)); } }; @@ -1819,9 +1953,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; Ok(Some(handle)) } - Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => { - Err(Error::Unexpected(function.span, ExpectedToken::Function)) - } + Some( + &LoweredGlobalDecl::Const(_) + | &LoweredGlobalDecl::Override(_) + | &LoweredGlobalDecl::Var(_), + ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)), Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)), Some(&LoweredGlobalDecl::Function(function)) => { let arguments = arguments @@ -1835,9 +1971,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); let result = has_result.then(|| { - rctx.function + let result = rctx + .function .expressions - .append(crate::Expression::CallResult(function), span) + .append(crate::Expression::CallResult(function), span); + rctx.local_expression_kind_tracker + .insert(result, crate::proc::ExpressionKind::Runtime); + result }); rctx.emitter.start(&rctx.function.expressions); rctx.block.push( @@ -1937,6 +2077,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx)? + } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { + return Ok(Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, + )); + } else if let Some(mode) = SubgroupGather::map(function.name) { + return Ok(Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)?, + )); + } else if let Some(fun) = crate::AtomicFunction::map(function.name) { + return Ok(Some(self.atomic_helper(span, fun, arguments, ctx)?)); } else { match function.name { "select" => { @@ -1982,70 +2132,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::Store { pointer, value }, span); return Ok(None); } - "atomicAdd" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Add, - arguments, - ctx, - )?)) - } - "atomicSub" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Subtract, - arguments, - ctx, - )?)) - } - "atomicAnd" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::And, - arguments, - ctx, - )?)) - } - "atomicOr" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::InclusiveOr, - arguments, - ctx, - )?)) - } - "atomicXor" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::ExclusiveOr, - arguments, - ctx, - )?)) - } - "atomicMin" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Min, - arguments, - ctx, - )?)) - } - "atomicMax" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Max, - arguments, - ctx, - )?)) - } - "atomicExchange" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Exchange { compare: None }, - arguments, - ctx, - )?)) - } "atomicCompareExchangeWeak" => { let mut args = ctx.prepare_args(arguments, 3, span); @@ -2104,6 +2190,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); return Ok(None); } + "subgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span); + return Ok(None); + } "workgroupUniformLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); let expr = args.next()?; @@ -2311,6 +2405,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; return Ok(Some(handle)); } + "subgroupBallot" => { + let mut args = ctx.prepare_args(arguments, 0, span); + let predicate = if arguments.len() == 1 { + Some(self.expression(args.next()?, ctx)?) + } else { + None + }; + args.finish()?; + + let result = ctx + .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::SubgroupBallot { result, predicate }, span); + return Ok(Some(result)); + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2502,6 +2612,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }) } + fn subgroup_operation_helper( + &mut self, + span: Span, + op: crate::SubgroupOperation, + collective_op: crate::CollectiveOperation, + arguments: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + }, + span, + ); + Ok(result) + } + + fn subgroup_gather_helper( + &mut self, + span: Span, + mode: SubgroupGather, + arguments: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 2, span); + + let argument = self.expression(args.next()?, ctx)?; + + use SubgroupGather as Sg; + let mode = if let Sg::BroadcastFirst = mode { + crate::GatherMode::BroadcastFirst + } else { + let index = self.expression(args.next()?, ctx)?; + match mode { + Sg::Broadcast => crate::GatherMode::Broadcast(index), + Sg::Shuffle => crate::GatherMode::Shuffle(index), + Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index), + Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index), + Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index), + Sg::BroadcastFirst => unreachable!(), + } + }; + + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode, + argument, + result, + }, + span, + ); + Ok(result) + } + fn r#struct( &mut self, s: &ast::Struct<'source>, @@ -2760,3 +2944,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } } + +impl crate::AtomicFunction { + pub fn map(word: &str) -> Option<Self> { + Some(match word { + "atomicAdd" => crate::AtomicFunction::Add, + "atomicSub" => crate::AtomicFunction::Subtract, + "atomicAnd" => crate::AtomicFunction::And, + "atomicOr" => crate::AtomicFunction::InclusiveOr, + "atomicXor" => crate::AtomicFunction::ExclusiveOr, + "atomicMin" => crate::AtomicFunction::Min, + "atomicMax" => crate::AtomicFunction::Max, + "atomicExchange" => crate::AtomicFunction::Exchange { compare: None }, + _ => return None, + }) + } +} diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs index b6151fe1c0..aec1e657fc 100644 --- a/third_party/rust/naga/src/front/wgsl/mod.rs +++ b/third_party/rust/naga/src/front/wgsl/mod.rs @@ -44,6 +44,17 @@ impl Frontend { } } +/// <div class="warning"> +// NOTE: Keep this in sync with `wgpu::Device::create_shader_module`! +// NOTE: Keep this in sync with `wgpu_core::Global::device_create_shader_module`! +/// +/// This function may consume a lot of stack space. Compiler-enforced limits for parsing recursion +/// exist; if shader compilation runs into them, it will return an error gracefully. However, on +/// some build profiles and platforms, the default stack size for a thread may be exceeded before +/// this limit is reached during parsing. Callers should ensure that there is enough stack space +/// for this, particularly if calls to this method are exposed to user input. +/// +/// </div> pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> { Frontend::new().parse(source) } diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs index dbaac523cb..ea8013ee7c 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/ast.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs @@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> { Fn(Function<'a>), Var(GlobalVariable<'a>), Const(Const<'a>), + Override(Override<'a>), Struct(Struct<'a>), Type(TypeAlias<'a>), } @@ -200,6 +201,14 @@ pub struct Const<'a> { pub init: Handle<Expression<'a>>, } +#[derive(Debug)] +pub struct Override<'a> { + pub name: Ident<'a>, + pub id: Option<Handle<Expression<'a>>>, + pub ty: Option<Handle<Type<'a>>>, + pub init: Option<Handle<Expression<'a>>>, +} + /// The size of an [`Array`] or [`BindingArray`]. /// /// [`Array`]: Type::Array diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs index 1a4911a3bd..207f0eda41 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/conv.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs @@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, + // subgroup + "num_subgroups" => crate::BuiltIn::NumSubgroups, + "subgroup_id" => crate::BuiltIn::SubgroupId, + "subgroup_size" => crate::BuiltIn::SubgroupSize, + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnknownBuiltin(span)), }) } @@ -260,3 +265,26 @@ pub fn map_conservative_depth( _ => Err(Error::UnknownConservativeDepth(span)), } } + +pub fn map_subgroup_operation( + word: &str, +) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> { + use crate::CollectiveOperation as co; + use crate::SubgroupOperation as sg; + Some(match word { + "subgroupAll" => (sg::All, co::Reduce), + "subgroupAny" => (sg::Any, co::Reduce), + "subgroupAdd" => (sg::Add, co::Reduce), + "subgroupMul" => (sg::Mul, co::Reduce), + "subgroupMin" => (sg::Min, co::Reduce), + "subgroupMax" => (sg::Max, co::Reduce), + "subgroupAnd" => (sg::And, co::Reduce), + "subgroupOr" => (sg::Or, co::Reduce), + "subgroupXor" => (sg::Xor, co::Reduce), + "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan), + "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan), + "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan), + "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs index 51fc2f013b..79ea1ae609 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/mod.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs @@ -1619,22 +1619,21 @@ impl Parser { lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, block: &mut ast::Block<'a>, + brace_nesting_level: u8, ) -> Result<(), Error<'a>> { self.push_rule_span(Rule::Statement, lexer); match lexer.peek() { (Token::Separator(';'), _) => { let _ = lexer.next(); self.pop_rule_span(lexer); - return Ok(()); } (Token::Paren('{'), _) => { - let (inner, span) = self.block(lexer, ctx)?; + let (inner, span) = self.block(lexer, ctx, brace_nesting_level)?; block.stmts.push(ast::Statement { kind: ast::StatementKind::Block(inner), span, }); self.pop_rule_span(lexer); - return Ok(()); } (Token::Word(word), _) => { let kind = match word { @@ -1711,7 +1710,7 @@ impl Parser { let _ = lexer.next(); let condition = self.general_expression(lexer, ctx)?; - let accept = self.block(lexer, ctx)?.0; + let accept = self.block(lexer, ctx, brace_nesting_level)?.0; let mut elsif_stack = Vec::new(); let mut elseif_span_start = lexer.start_byte_offset(); @@ -1722,12 +1721,12 @@ impl Parser { if !lexer.skip(Token::Word("if")) { // ... else { ... } - break self.block(lexer, ctx)?.0; + break self.block(lexer, ctx, brace_nesting_level)?.0; } // ... else if (...) { ... } let other_condition = self.general_expression(lexer, ctx)?; - let other_block = self.block(lexer, ctx)?; + let other_block = self.block(lexer, ctx, brace_nesting_level)?; elsif_stack.push((elseif_span_start, other_condition, other_block)); elseif_span_start = lexer.start_byte_offset(); }; @@ -1759,7 +1758,9 @@ impl Parser { "switch" => { let _ = lexer.next(); let selector = self.general_expression(lexer, ctx)?; - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = + Self::increase_brace_nesting(brace_nesting_level, brace_span)?; let mut cases = Vec::new(); loop { @@ -1784,7 +1785,7 @@ impl Parser { }); }; - let body = self.block(lexer, ctx)?.0; + let body = self.block(lexer, ctx, brace_nesting_level)?.0; cases.push(ast::SwitchCase { value, @@ -1794,7 +1795,7 @@ impl Parser { } (Token::Word("default"), _) => { lexer.skip(Token::Separator(':')); - let body = self.block(lexer, ctx)?.0; + let body = self.block(lexer, ctx, brace_nesting_level)?.0; cases.push(ast::SwitchCase { value: ast::SwitchValue::Default, body, @@ -1810,7 +1811,7 @@ impl Parser { ast::StatementKind::Switch { selector, cases } } - "loop" => self.r#loop(lexer, ctx)?, + "loop" => self.r#loop(lexer, ctx, brace_nesting_level)?, "while" => { let _ = lexer.next(); let mut body = ast::Block::default(); @@ -1834,7 +1835,7 @@ impl Parser { span, }); - let (block, span) = self.block(lexer, ctx)?; + let (block, span) = self.block(lexer, ctx, brace_nesting_level)?; body.stmts.push(ast::Statement { kind: ast::StatementKind::Block(block), span, @@ -1857,7 +1858,9 @@ impl Parser { let (_, span) = { let ctx = &mut *ctx; let block = &mut *block; - lexer.capture_span(|lexer| self.statement(lexer, ctx, block))? + lexer.capture_span(|lexer| { + self.statement(lexer, ctx, block, brace_nesting_level) + })? }; if block.stmts.len() != num_statements { @@ -1902,7 +1905,7 @@ impl Parser { lexer.expect(Token::Paren(')'))?; } - let (block, span) = self.block(lexer, ctx)?; + let (block, span) = self.block(lexer, ctx, brace_nesting_level)?; body.stmts.push(ast::Statement { kind: ast::StatementKind::Block(block), span, @@ -1964,13 +1967,15 @@ impl Parser { &mut self, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, + brace_nesting_level: u8, ) -> Result<ast::StatementKind<'a>, Error<'a>> { let _ = lexer.next(); let mut body = ast::Block::default(); let mut continuing = ast::Block::default(); let mut break_if = None; - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?; ctx.local_table.push_scope(); @@ -1980,7 +1985,9 @@ impl Parser { // the last thing in the loop body // Expect a opening brace to start the continuing block - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = + Self::increase_brace_nesting(brace_nesting_level, brace_span)?; loop { if lexer.skip(Token::Word("break")) { // Branch for the `break if` statement, this statement @@ -2009,7 +2016,7 @@ impl Parser { break; } else { // Otherwise try to parse a statement - self.statement(lexer, ctx, &mut continuing)?; + self.statement(lexer, ctx, &mut continuing, brace_nesting_level)?; } } // Since the continuing block must be the last part of the loop body, @@ -2023,7 +2030,7 @@ impl Parser { break; } // Otherwise try to parse a statement - self.statement(lexer, ctx, &mut body)?; + self.statement(lexer, ctx, &mut body, brace_nesting_level)?; } ctx.local_table.pop_scope(); @@ -2040,15 +2047,17 @@ impl Parser { &mut self, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, + brace_nesting_level: u8, ) -> Result<(ast::Block<'a>, Span), Error<'a>> { self.push_rule_span(Rule::Block, lexer); ctx.local_table.push_scope(); - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?; let mut block = ast::Block::default(); while !lexer.skip(Token::Paren('}')) { - self.statement(lexer, ctx, &mut block)?; + self.statement(lexer, ctx, &mut block, brace_nesting_level)?; } ctx.local_table.pop_scope(); @@ -2135,9 +2144,10 @@ impl Parser { // do not use `self.block` here, since we must not push a new scope lexer.expect(Token::Paren('{'))?; + let brace_nesting_level = 1; let mut body = ast::Block::default(); while !lexer.skip(Token::Paren('}')) { - self.statement(lexer, &mut ctx, &mut body)?; + self.statement(lexer, &mut ctx, &mut body, brace_nesting_level)?; } ctx.local_table.pop_scope(); @@ -2170,6 +2180,7 @@ impl Parser { let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); + let mut id = ParsedAttribute::default(); let mut dependencies = FastIndexSet::default(); let mut ctx = ExpressionContext { @@ -2193,6 +2204,11 @@ impl Parser { bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } + ("id", name_span) => { + lexer.expect(Token::Paren('('))?; + id.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } ("vertex", name_span) => { stage.set(crate::ShaderStage::Vertex, name_span)?; } @@ -2283,6 +2299,30 @@ impl Parser { Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) } + (Token::Word("override"), _) => { + let name = lexer.next_ident()?; + + let ty = if lexer.skip(Token::Separator(':')) { + Some(self.type_decl(lexer, &mut ctx)?) + } else { + None + }; + + let init = if lexer.skip(Token::Operation('=')) { + Some(self.general_expression(lexer, &mut ctx)?) + } else { + None + }; + + lexer.expect(Token::Separator(';'))?; + + Some(ast::GlobalDeclKind::Override(ast::Override { + name, + id: id.value, + ty, + init, + })) + } (Token::Word("var"), _) => { let mut var = self.variable_decl(lexer, &mut ctx)?; var.binding = binding.take(); @@ -2347,4 +2387,30 @@ impl Parser { Ok(tu) } + + const fn increase_brace_nesting( + brace_nesting_level: u8, + brace_span: Span, + ) -> Result<u8, Error<'static>> { + // From [spec.](https://gpuweb.github.io/gpuweb/wgsl/#limits): + // + // > § 2.4. Limits + // > + // > … + // > + // > Maximum nesting depth of brace-enclosed statements in a function[:] 127 + // + // _However_, we choose 64 instead because (a) it avoids stack overflows in CI and + // (b) we expect the limit to be decreased to 63 based on this conversation in + // WebGPU CTS upstream: + // <https://github.com/gpuweb/cts/pull/3389#discussion_r1543742701> + const BRACE_NESTING_MAXIMUM: u8 = 64; + if brace_nesting_level + 1 > BRACE_NESTING_MAXIMUM { + return Err(Error::ExceededLimitForNestedBraces { + span: brace_span, + limit: BRACE_NESTING_MAXIMUM, + }); + } + Ok(brace_nesting_level + 1) + } } diff --git a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs index c8331ace09..63bc9f7317 100644 --- a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs +++ b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs @@ -226,7 +226,8 @@ mod tests { let gctx = crate::proc::GlobalCtx { types: &types, constants: &crate::Arena::new(), - const_expressions: &crate::Arena::new(), + overrides: &crate::Arena::new(), + global_expressions: &crate::Arena::new(), }; let array = crate::TypeInner::Array { base: mytype1, |