summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/front/wgsl
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/front/wgsl')
-rw-r--r--third_party/rust/naga/src/front/wgsl/error.rs29
-rw-r--r--third_party/rust/naga/src/front/wgsl/index.rs1
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/mod.rs414
-rw-r--r--third_party/rust/naga/src/front/wgsl/mod.rs11
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/ast.rs9
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/conv.rs28
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/mod.rs106
-rw-r--r--third_party/rust/naga/src/front/wgsl/to_wgsl.rs3
8 files changed, 469 insertions, 132 deletions
diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs
index 54aa8296b1..dc1339521c 100644
--- a/third_party/rust/naga/src/front/wgsl/error.rs
+++ b/third_party/rust/naga/src/front/wgsl/error.rs
@@ -13,6 +13,7 @@ use thiserror::Error;
#[derive(Clone, Debug)]
pub struct ParseError {
message: String,
+ // The first span should be the primary span, and the other ones should be complementary.
labels: Vec<(Span, Cow<'static, str>)>,
notes: Vec<String>,
}
@@ -190,7 +191,7 @@ pub enum Error<'a> {
expected: String,
got: String,
},
- MissingType(Span),
+ DeclMissingTypeAndInit(Span),
MissingAttribute(&'static str, Span),
InvalidAtomicPointer(Span),
InvalidAtomicOperandType(Span),
@@ -269,6 +270,11 @@ pub enum Error<'a> {
scalar: String,
inner: ConstantEvaluatorError,
},
+ ExceededLimitForNestedBraces {
+ span: Span,
+ limit: u8,
+ },
+ PipelineConstantIDValue(Span),
}
impl<'a> Error<'a> {
@@ -518,11 +524,11 @@ impl<'a> Error<'a> {
notes: vec![],
}
}
- Error::MissingType(name_span) => ParseError {
- message: format!("variable `{}` needs a type", &source[name_span]),
+ Error::DeclMissingTypeAndInit(name_span) => ParseError {
+ message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]),
labels: vec![(
name_span,
- format!("definition of `{}`", &source[name_span]).into(),
+ "needs a type specifier or initializer".into(),
)],
notes: vec![],
},
@@ -770,6 +776,21 @@ impl<'a> Error<'a> {
format!("the expression should have been converted to have {} scalar type", scalar),
]
},
+ Error::ExceededLimitForNestedBraces { span, limit } => ParseError {
+ message: "brace nesting limit reached".into(),
+ labels: vec![(span, "limit reached at this brace".into())],
+ notes: vec![
+ format!("nesting limit is currently set to {limit}"),
+ ],
+ },
+ Error::PipelineConstantIDValue(span) => ParseError {
+ message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(),
+ labels: vec![(
+ span,
+ "must be between 0 and 65535 inclusive".into(),
+ )],
+ notes: vec![],
+ },
}
}
}
diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs
index a5524fe8f1..593405508f 100644
--- a/third_party/rust/naga/src/front/wgsl/index.rs
+++ b/third_party/rust/naga/src/front/wgsl/index.rs
@@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
ast::GlobalDeclKind::Fn(ref f) => f.name,
ast::GlobalDeclKind::Var(ref v) => v.name,
ast::GlobalDeclKind::Const(ref c) => c.name,
+ ast::GlobalDeclKind::Override(ref o) => o.name,
ast::GlobalDeclKind::Struct(ref s) => s.name,
ast::GlobalDeclKind::Type(ref t) => t.name,
}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
index 2ca6c182b7..e7cce17723 100644
--- a/third_party/rust/naga/src/front/wgsl/lower/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
@@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> {
module: &'out mut crate::Module,
const_typifier: &'temp mut Typifier,
+
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
impl<'source> GlobalContext<'source, '_, '_> {
@@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> {
module: self.module,
const_typifier: self.const_typifier,
expr_type: ExpressionContextType::Constant,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
+ }
+ }
+
+ fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> {
+ ExpressionContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ expr_type: ExpressionContextType::Override,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -164,7 +179,8 @@ pub struct StatementContext<'source, 'temp, 'out> {
/// with the form of the expressions; it is also tracking whether WGSL says
/// we should consider them to be const. See the use of `force_non_const` in
/// the code for lowering `let` bindings.
- expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+ local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
@@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
ast_expressions: self.ast_expressions,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
module: self.module,
expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext {
local_table: self.local_table,
@@ -188,7 +205,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
block,
emitter,
typifier: self.typifier,
- expression_constness: self.expression_constness,
+ local_expression_kind_tracker: self.local_expression_kind_tracker,
}),
}
}
@@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -232,8 +250,8 @@ pub struct RuntimeExpressionContext<'temp, 'out> {
/// Which `Expression`s in `self.naga_expressions` are const expressions, in
/// the WGSL sense.
///
- /// See [`StatementContext::expression_constness`] for details.
- expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+ /// See [`StatementContext::local_expression_kind_tracker`] for details.
+ local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
/// The type of Naga IR expression we are lowering an [`ast::Expression`] to.
@@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> {
/// available in the [`ExpressionContext`], so this variant
/// carries no further information.
Constant,
+
+ /// We are lowering to an override expression, to be included in the module's
+ /// constant expression arena.
+ ///
+ /// Everything override expressions are allowed to refer to is
+ /// available in the [`ExpressionContext`], so this variant
+ /// carries no further information.
+ Override,
}
/// State for lowering an [`ast::Expression`] to Naga IR.
@@ -307,10 +333,11 @@ pub struct ExpressionContext<'source, 'temp, 'out> {
/// [`Module`]: crate::Module
module: &'out mut crate::Module,
- /// Type judgments for [`module::const_expressions`].
+ /// Type judgments for [`module::global_expressions`].
///
- /// [`module::const_expressions`]: crate::Module::const_expressions
+ /// [`module::global_expressions`]: crate::Module::global_expressions
const_typifier: &'temp mut Typifier,
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
/// Whether we are lowering a constant expression or a general
/// runtime expression, and the data needed in each case.
@@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
const_typifier: self.const_typifier,
module: self.module,
expr_type: ExpressionContextType::Constant,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -344,11 +373,20 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function(
self.module,
&mut rctx.function.expressions,
- rctx.expression_constness,
+ rctx.local_expression_kind_tracker,
rctx.emitter,
rctx.block,
),
- ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module),
+ ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ false,
+ ),
+ ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ true,
+ ),
}
}
@@ -358,24 +396,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
span: Span,
) -> Result<Handle<crate::Expression>, Error<'source>> {
let mut eval = self.as_const_evaluator();
- match eval.try_eval_and_append(&expr, span) {
- Ok(expr) => Ok(expr),
-
- // `expr` is not a constant expression. This is fine as
- // long as we're not building `Module::const_expressions`.
- Err(err) => match self.expr_type {
- ExpressionContextType::Runtime(ref mut rctx) => {
- Ok(rctx.function.expressions.append(expr, span))
- }
- ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)),
- },
- }
+ eval.try_eval_and_append(expr, span)
+ .map_err(|e| Error::ConstantEvaluatorError(e, span))
}
fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => {
- if !ctx.expression_constness.is_const(handle) {
+ if !ctx.local_expression_kind_tracker.is_const(handle) {
return None;
}
@@ -385,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
.ok()
}
ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
+ ExpressionContextType::Override => None,
}
}
fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle),
- ExpressionContextType::Constant => self.module.const_expressions.get_span(handle),
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ self.module.global_expressions.get_span(handle)
+ }
}
}
fn typifier(&self) -> &Typifier {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
- ExpressionContextType::Constant => self.const_typifier,
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ self.const_typifier
+ }
}
}
@@ -408,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
- ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)),
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ Err(Error::UnexpectedOperationInConstContext(span))
+ }
}
}
@@ -420,7 +455,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<crate::SwizzleComponent, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref rctx) => {
- if !rctx.expression_constness.is_const(expr) {
+ if !rctx.local_expression_kind_tracker.is_const(expr) {
return Err(Error::ExpectedConstExprConcreteIntegerScalar(
component_span,
));
@@ -445,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
// This means a `gather` operation appeared in a constant expression.
// This error refers to the `gather` itself, not its "component" argument.
- ExpressionContextType::Constant => {
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(gather_span))
}
}
@@ -471,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
// to also borrow self.module.types mutably below.
let typifier = match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
- ExpressionContextType::Constant => &*self.const_typifier,
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ &*self.const_typifier
+ }
};
Ok(typifier.register_type(handle, &mut self.module.types))
}
@@ -514,10 +551,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
typifier = &mut *ctx.typifier;
expressions = &ctx.function.expressions;
}
- ExpressionContextType::Constant => {
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]);
typifier = self.const_typifier;
- expressions = &self.module.const_expressions;
+ expressions = &self.module.global_expressions;
}
};
typifier
@@ -610,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
}
- ExpressionContextType::Constant => {}
+ ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
let result = self.append_expression(expression, span);
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
rctx.emitter.start(&rctx.function.expressions);
}
- ExpressionContextType::Constant => {}
+ ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
result
}
@@ -786,6 +823,7 @@ enum LoweredGlobalDecl {
Function(Handle<crate::Function>),
Var(Handle<crate::GlobalVariable>),
Const(Handle<crate::Constant>),
+ Override(Handle<crate::Override>),
Type(Handle<crate::Type>),
EntryPoint,
}
@@ -836,6 +874,29 @@ impl Texture {
}
}
+enum SubgroupGather {
+ BroadcastFirst,
+ Broadcast,
+ Shuffle,
+ ShuffleDown,
+ ShuffleUp,
+ ShuffleXor,
+}
+
+impl SubgroupGather {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "subgroupBroadcastFirst" => Self::BroadcastFirst,
+ "subgroupBroadcast" => Self::Broadcast,
+ "subgroupShuffle" => Self::Shuffle,
+ "subgroupShuffleDown" => Self::ShuffleDown,
+ "subgroupShuffleUp" => Self::ShuffleUp,
+ "subgroupShuffleXor" => Self::ShuffleXor,
+ _ => return None,
+ })
+ }
+}
+
pub struct Lowerer<'source, 'temp> {
index: &'temp Index<'source>,
layouter: Layouter,
@@ -861,6 +922,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
types: &tu.types,
module: &mut module,
const_typifier: &mut Typifier::new(),
+ global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker::new(),
};
for decl_handle in self.index.visit_ordered() {
@@ -877,7 +939,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let init;
if let Some(init_ast) = v.init {
- let mut ectx = ctx.as_const();
+ let mut ectx = ctx.as_override();
let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(ty);
let converted = ectx
@@ -956,7 +1018,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let handle = ctx.module.constants.append(
crate::Constant {
name: Some(c.name.name.to_string()),
- r#override: crate::Override::None,
ty,
init,
},
@@ -966,6 +1027,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ctx.globals
.insert(c.name.name, LoweredGlobalDecl::Const(handle));
}
+ ast::GlobalDeclKind::Override(ref o) => {
+ let init = o
+ .init
+ .map(|init| self.expression(init, &mut ctx.as_override()))
+ .transpose()?;
+ let inferred_type = init
+ .map(|init| ctx.as_const().register_type(init))
+ .transpose()?;
+
+ let explicit_ty =
+ o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx))
+ .transpose()?;
+
+ let id =
+ o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
+ .transpose()?;
+
+ let id = if let Some((id, id_span)) = id {
+ Some(
+ u16::try_from(id)
+ .map_err(|_| Error::PipelineConstantIDValue(id_span))?,
+ )
+ } else {
+ None
+ };
+
+ let ty = match (explicit_ty, inferred_type) {
+ (Some(explicit_ty), Some(inferred_type)) => {
+ if explicit_ty == inferred_type {
+ explicit_ty
+ } else {
+ let gctx = ctx.module.to_ctx();
+ return Err(Error::InitializationTypeMismatch {
+ name: o.name.span,
+ expected: explicit_ty.to_wgsl(&gctx),
+ got: inferred_type.to_wgsl(&gctx),
+ });
+ }
+ }
+ (Some(explicit_ty), None) => explicit_ty,
+ (None, Some(inferred_type)) => inferred_type,
+ (None, None) => {
+ return Err(Error::DeclMissingTypeAndInit(o.name.span));
+ }
+ };
+
+ let handle = ctx.module.overrides.append(
+ crate::Override {
+ name: Some(o.name.name.to_string()),
+ id,
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(o.name.name, LoweredGlobalDecl::Override(handle));
+ }
ast::GlobalDeclKind::Struct(ref s) => {
let handle = self.r#struct(s, span, &mut ctx)?;
ctx.globals
@@ -1000,6 +1120,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut local_table = FastHashMap::default();
let mut expressions = Arena::new();
let mut named_expressions = FastIndexMap::default();
+ let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
let arguments = f
.arguments
@@ -1011,6 +1132,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
local_table.insert(arg.handle, Typed::Plain(expr));
named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
+ local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime);
Ok(crate::FunctionArgument {
name: Some(arg.name.name.to_string()),
@@ -1053,7 +1175,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
named_expressions: &mut named_expressions,
types: ctx.types,
module: ctx.module,
- expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(),
+ local_expression_kind_tracker: &mut local_expression_kind_tracker,
+ global_expression_kind_tracker: ctx.global_expression_kind_tracker,
};
let mut body = self.block(&f.body, false, &mut stmt_ctx)?;
ensure_block_returns(&mut body);
@@ -1132,7 +1255,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// affects when errors must be reported, so we can't even
// treat suitable `let` bindings as constant as an
// optimization.
- ctx.expression_constness.force_non_const(value);
+ ctx.local_expression_kind_tracker.force_non_const(value);
let explicit_ty =
l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
@@ -1203,7 +1326,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty = explicit_ty;
initializer = None;
}
- (None, None) => return Err(Error::MissingType(v.name.span)),
+ (None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)),
}
let (const_initializer, initializer) = {
@@ -1216,7 +1339,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// - the initialization is not a constant
// expression, so its value depends on the
// state at the point of initialization.
- if is_inside_loop || !ctx.expression_constness.is_const(init) {
+ if is_inside_loop
+ || !ctx.local_expression_kind_tracker.is_const_or_override(init)
+ {
(None, Some(init))
} else {
(Some(init), None)
@@ -1469,6 +1594,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.function
.expressions
.append(crate::Expression::Binary { op, left, right }, stmt.span);
+ rctx.local_expression_kind_tracker
+ .insert(left, crate::proc::ExpressionKind::Runtime);
+ rctx.local_expression_kind_tracker
+ .insert(value, crate::proc::ExpressionKind::Runtime);
block.extend(emitter.finish(&ctx.function.expressions));
crate::Statement::Store {
@@ -1562,7 +1691,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
LoweredGlobalDecl::Const(handle) => {
Typed::Plain(crate::Expression::Constant(handle))
}
- _ => {
+ LoweredGlobalDecl::Override(handle) => {
+ Typed::Plain(crate::Expression::Override(handle))
+ }
+ LoweredGlobalDecl::Function(_)
+ | LoweredGlobalDecl::Type(_)
+ | LoweredGlobalDecl::EntryPoint => {
return Err(Error::Unexpected(span, ExpectedToken::Variable));
}
};
@@ -1819,9 +1953,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
Ok(Some(handle))
}
- Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
- Err(Error::Unexpected(function.span, ExpectedToken::Function))
- }
+ Some(
+ &LoweredGlobalDecl::Const(_)
+ | &LoweredGlobalDecl::Override(_)
+ | &LoweredGlobalDecl::Var(_),
+ ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)),
Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
Some(&LoweredGlobalDecl::Function(function)) => {
let arguments = arguments
@@ -1835,9 +1971,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
let result = has_result.then(|| {
- rctx.function
+ let result = rctx
+ .function
.expressions
- .append(crate::Expression::CallResult(function), span)
+ .append(crate::Expression::CallResult(function), span);
+ rctx.local_expression_kind_tracker
+ .insert(result, crate::proc::ExpressionKind::Runtime);
+ result
});
rctx.emitter.start(&rctx.function.expressions);
rctx.block.push(
@@ -1937,6 +2077,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
} else if let Some(fun) = Texture::map(function.name) {
self.texture_sample_helper(fun, arguments, span, ctx)?
+ } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
+ return Ok(Some(
+ self.subgroup_operation_helper(span, op, cop, arguments, ctx)?,
+ ));
+ } else if let Some(mode) = SubgroupGather::map(function.name) {
+ return Ok(Some(
+ self.subgroup_gather_helper(span, mode, arguments, ctx)?,
+ ));
+ } else if let Some(fun) = crate::AtomicFunction::map(function.name) {
+ return Ok(Some(self.atomic_helper(span, fun, arguments, ctx)?));
} else {
match function.name {
"select" => {
@@ -1982,70 +2132,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Store { pointer, value }, span);
return Ok(None);
}
- "atomicAdd" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Add,
- arguments,
- ctx,
- )?))
- }
- "atomicSub" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Subtract,
- arguments,
- ctx,
- )?))
- }
- "atomicAnd" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::And,
- arguments,
- ctx,
- )?))
- }
- "atomicOr" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::InclusiveOr,
- arguments,
- ctx,
- )?))
- }
- "atomicXor" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::ExclusiveOr,
- arguments,
- ctx,
- )?))
- }
- "atomicMin" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Min,
- arguments,
- ctx,
- )?))
- }
- "atomicMax" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Max,
- arguments,
- ctx,
- )?))
- }
- "atomicExchange" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Exchange { compare: None },
- arguments,
- ctx,
- )?))
- }
"atomicCompareExchangeWeak" => {
let mut args = ctx.prepare_args(arguments, 3, span);
@@ -2104,6 +2190,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
return Ok(None);
}
+ "subgroupBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span);
+ return Ok(None);
+ }
"workgroupUniformLoad" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let expr = args.next()?;
@@ -2311,6 +2405,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
return Ok(Some(handle));
}
+ "subgroupBallot" => {
+ let mut args = ctx.prepare_args(arguments, 0, span);
+ let predicate = if arguments.len() == 1 {
+ Some(self.expression(args.next()?, ctx)?)
+ } else {
+ None
+ };
+ args.finish()?;
+
+ let result = ctx
+ .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::SubgroupBallot { result, predicate }, span);
+ return Ok(Some(result));
+ }
_ => return Err(Error::UnknownIdent(function.span, function.name)),
}
};
@@ -2502,6 +2612,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
})
}
+ fn subgroup_operation_helper(
+ &mut self,
+ span: Span,
+ op: crate::SubgroupOperation,
+ collective_op: crate::CollectiveOperation,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+
+ let argument = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ let ty = ctx.register_type(argument)?;
+
+ let result =
+ ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
+ fn subgroup_gather_helper(
+ &mut self,
+ span: Span,
+ mode: SubgroupGather,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+
+ let argument = self.expression(args.next()?, ctx)?;
+
+ use SubgroupGather as Sg;
+ let mode = if let Sg::BroadcastFirst = mode {
+ crate::GatherMode::BroadcastFirst
+ } else {
+ let index = self.expression(args.next()?, ctx)?;
+ match mode {
+ Sg::Broadcast => crate::GatherMode::Broadcast(index),
+ Sg::Shuffle => crate::GatherMode::Shuffle(index),
+ Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index),
+ Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index),
+ Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index),
+ Sg::BroadcastFirst => unreachable!(),
+ }
+ };
+
+ args.finish()?;
+
+ let ty = ctx.register_type(argument)?;
+
+ let result =
+ ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
fn r#struct(
&mut self,
s: &ast::Struct<'source>,
@@ -2760,3 +2944,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
}
}
+
+impl crate::AtomicFunction {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "atomicAdd" => crate::AtomicFunction::Add,
+ "atomicSub" => crate::AtomicFunction::Subtract,
+ "atomicAnd" => crate::AtomicFunction::And,
+ "atomicOr" => crate::AtomicFunction::InclusiveOr,
+ "atomicXor" => crate::AtomicFunction::ExclusiveOr,
+ "atomicMin" => crate::AtomicFunction::Min,
+ "atomicMax" => crate::AtomicFunction::Max,
+ "atomicExchange" => crate::AtomicFunction::Exchange { compare: None },
+ _ => return None,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs
index b6151fe1c0..aec1e657fc 100644
--- a/third_party/rust/naga/src/front/wgsl/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/mod.rs
@@ -44,6 +44,17 @@ impl Frontend {
}
}
+/// <div class="warning">
+// NOTE: Keep this in sync with `wgpu::Device::create_shader_module`!
+// NOTE: Keep this in sync with `wgpu_core::Global::device_create_shader_module`!
+///
+/// This function may consume a lot of stack space. Compiler-enforced limits for parsing recursion
+/// exist; if shader compilation runs into them, it will return an error gracefully. However, on
+/// some build profiles and platforms, the default stack size for a thread may be exceeded before
+/// this limit is reached during parsing. Callers should ensure that there is enough stack space
+/// for this, particularly if calls to this method are exposed to user input.
+///
+/// </div>
pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
Frontend::new().parse(source)
}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
index dbaac523cb..ea8013ee7c 100644
--- a/third_party/rust/naga/src/front/wgsl/parse/ast.rs
+++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
@@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> {
Fn(Function<'a>),
Var(GlobalVariable<'a>),
Const(Const<'a>),
+ Override(Override<'a>),
Struct(Struct<'a>),
Type(TypeAlias<'a>),
}
@@ -200,6 +201,14 @@ pub struct Const<'a> {
pub init: Handle<Expression<'a>>,
}
+#[derive(Debug)]
+pub struct Override<'a> {
+ pub name: Ident<'a>,
+ pub id: Option<Handle<Expression<'a>>>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Option<Handle<Expression<'a>>>,
+}
+
/// The size of an [`Array`] or [`BindingArray`].
///
/// [`Array`]: Type::Array
diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
index 1a4911a3bd..207f0eda41 100644
--- a/third_party/rust/naga/src/front/wgsl/parse/conv.rs
+++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
@@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>>
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
"workgroup_id" => crate::BuiltIn::WorkGroupId,
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
+ // subgroup
+ "num_subgroups" => crate::BuiltIn::NumSubgroups,
+ "subgroup_id" => crate::BuiltIn::SubgroupId,
+ "subgroup_size" => crate::BuiltIn::SubgroupSize,
+ "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
_ => return Err(Error::UnknownBuiltin(span)),
})
}
@@ -260,3 +265,26 @@ pub fn map_conservative_depth(
_ => Err(Error::UnknownConservativeDepth(span)),
}
}
+
+pub fn map_subgroup_operation(
+ word: &str,
+) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
+ use crate::CollectiveOperation as co;
+ use crate::SubgroupOperation as sg;
+ Some(match word {
+ "subgroupAll" => (sg::All, co::Reduce),
+ "subgroupAny" => (sg::Any, co::Reduce),
+ "subgroupAdd" => (sg::Add, co::Reduce),
+ "subgroupMul" => (sg::Mul, co::Reduce),
+ "subgroupMin" => (sg::Min, co::Reduce),
+ "subgroupMax" => (sg::Max, co::Reduce),
+ "subgroupAnd" => (sg::And, co::Reduce),
+ "subgroupOr" => (sg::Or, co::Reduce),
+ "subgroupXor" => (sg::Xor, co::Reduce),
+ "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
+ "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
+ "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
+ "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
+ _ => return None,
+ })
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
index 51fc2f013b..79ea1ae609 100644
--- a/third_party/rust/naga/src/front/wgsl/parse/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
@@ -1619,22 +1619,21 @@ impl Parser {
lexer: &mut Lexer<'a>,
ctx: &mut ExpressionContext<'a, '_, '_>,
block: &mut ast::Block<'a>,
+ brace_nesting_level: u8,
) -> Result<(), Error<'a>> {
self.push_rule_span(Rule::Statement, lexer);
match lexer.peek() {
(Token::Separator(';'), _) => {
let _ = lexer.next();
self.pop_rule_span(lexer);
- return Ok(());
}
(Token::Paren('{'), _) => {
- let (inner, span) = self.block(lexer, ctx)?;
+ let (inner, span) = self.block(lexer, ctx, brace_nesting_level)?;
block.stmts.push(ast::Statement {
kind: ast::StatementKind::Block(inner),
span,
});
self.pop_rule_span(lexer);
- return Ok(());
}
(Token::Word(word), _) => {
let kind = match word {
@@ -1711,7 +1710,7 @@ impl Parser {
let _ = lexer.next();
let condition = self.general_expression(lexer, ctx)?;
- let accept = self.block(lexer, ctx)?.0;
+ let accept = self.block(lexer, ctx, brace_nesting_level)?.0;
let mut elsif_stack = Vec::new();
let mut elseif_span_start = lexer.start_byte_offset();
@@ -1722,12 +1721,12 @@ impl Parser {
if !lexer.skip(Token::Word("if")) {
// ... else { ... }
- break self.block(lexer, ctx)?.0;
+ break self.block(lexer, ctx, brace_nesting_level)?.0;
}
// ... else if (...) { ... }
let other_condition = self.general_expression(lexer, ctx)?;
- let other_block = self.block(lexer, ctx)?;
+ let other_block = self.block(lexer, ctx, brace_nesting_level)?;
elsif_stack.push((elseif_span_start, other_condition, other_block));
elseif_span_start = lexer.start_byte_offset();
};
@@ -1759,7 +1758,9 @@ impl Parser {
"switch" => {
let _ = lexer.next();
let selector = self.general_expression(lexer, ctx)?;
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level =
+ Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
let mut cases = Vec::new();
loop {
@@ -1784,7 +1785,7 @@ impl Parser {
});
};
- let body = self.block(lexer, ctx)?.0;
+ let body = self.block(lexer, ctx, brace_nesting_level)?.0;
cases.push(ast::SwitchCase {
value,
@@ -1794,7 +1795,7 @@ impl Parser {
}
(Token::Word("default"), _) => {
lexer.skip(Token::Separator(':'));
- let body = self.block(lexer, ctx)?.0;
+ let body = self.block(lexer, ctx, brace_nesting_level)?.0;
cases.push(ast::SwitchCase {
value: ast::SwitchValue::Default,
body,
@@ -1810,7 +1811,7 @@ impl Parser {
ast::StatementKind::Switch { selector, cases }
}
- "loop" => self.r#loop(lexer, ctx)?,
+ "loop" => self.r#loop(lexer, ctx, brace_nesting_level)?,
"while" => {
let _ = lexer.next();
let mut body = ast::Block::default();
@@ -1834,7 +1835,7 @@ impl Parser {
span,
});
- let (block, span) = self.block(lexer, ctx)?;
+ let (block, span) = self.block(lexer, ctx, brace_nesting_level)?;
body.stmts.push(ast::Statement {
kind: ast::StatementKind::Block(block),
span,
@@ -1857,7 +1858,9 @@ impl Parser {
let (_, span) = {
let ctx = &mut *ctx;
let block = &mut *block;
- lexer.capture_span(|lexer| self.statement(lexer, ctx, block))?
+ lexer.capture_span(|lexer| {
+ self.statement(lexer, ctx, block, brace_nesting_level)
+ })?
};
if block.stmts.len() != num_statements {
@@ -1902,7 +1905,7 @@ impl Parser {
lexer.expect(Token::Paren(')'))?;
}
- let (block, span) = self.block(lexer, ctx)?;
+ let (block, span) = self.block(lexer, ctx, brace_nesting_level)?;
body.stmts.push(ast::Statement {
kind: ast::StatementKind::Block(block),
span,
@@ -1964,13 +1967,15 @@ impl Parser {
&mut self,
lexer: &mut Lexer<'a>,
ctx: &mut ExpressionContext<'a, '_, '_>,
+ brace_nesting_level: u8,
) -> Result<ast::StatementKind<'a>, Error<'a>> {
let _ = lexer.next();
let mut body = ast::Block::default();
let mut continuing = ast::Block::default();
let mut break_if = None;
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
ctx.local_table.push_scope();
@@ -1980,7 +1985,9 @@ impl Parser {
// the last thing in the loop body
// Expect a opening brace to start the continuing block
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level =
+ Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
loop {
if lexer.skip(Token::Word("break")) {
// Branch for the `break if` statement, this statement
@@ -2009,7 +2016,7 @@ impl Parser {
break;
} else {
// Otherwise try to parse a statement
- self.statement(lexer, ctx, &mut continuing)?;
+ self.statement(lexer, ctx, &mut continuing, brace_nesting_level)?;
}
}
// Since the continuing block must be the last part of the loop body,
@@ -2023,7 +2030,7 @@ impl Parser {
break;
}
// Otherwise try to parse a statement
- self.statement(lexer, ctx, &mut body)?;
+ self.statement(lexer, ctx, &mut body, brace_nesting_level)?;
}
ctx.local_table.pop_scope();
@@ -2040,15 +2047,17 @@ impl Parser {
&mut self,
lexer: &mut Lexer<'a>,
ctx: &mut ExpressionContext<'a, '_, '_>,
+ brace_nesting_level: u8,
) -> Result<(ast::Block<'a>, Span), Error<'a>> {
self.push_rule_span(Rule::Block, lexer);
ctx.local_table.push_scope();
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
let mut block = ast::Block::default();
while !lexer.skip(Token::Paren('}')) {
- self.statement(lexer, ctx, &mut block)?;
+ self.statement(lexer, ctx, &mut block, brace_nesting_level)?;
}
ctx.local_table.pop_scope();
@@ -2135,9 +2144,10 @@ impl Parser {
// do not use `self.block` here, since we must not push a new scope
lexer.expect(Token::Paren('{'))?;
+ let brace_nesting_level = 1;
let mut body = ast::Block::default();
while !lexer.skip(Token::Paren('}')) {
- self.statement(lexer, &mut ctx, &mut body)?;
+ self.statement(lexer, &mut ctx, &mut body, brace_nesting_level)?;
}
ctx.local_table.pop_scope();
@@ -2170,6 +2180,7 @@ impl Parser {
let mut early_depth_test = ParsedAttribute::default();
let (mut bind_index, mut bind_group) =
(ParsedAttribute::default(), ParsedAttribute::default());
+ let mut id = ParsedAttribute::default();
let mut dependencies = FastIndexSet::default();
let mut ctx = ExpressionContext {
@@ -2193,6 +2204,11 @@ impl Parser {
bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
+ ("id", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ id.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
("vertex", name_span) => {
stage.set(crate::ShaderStage::Vertex, name_span)?;
}
@@ -2283,6 +2299,30 @@ impl Parser {
Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init }))
}
+ (Token::Word("override"), _) => {
+ let name = lexer.next_ident()?;
+
+ let ty = if lexer.skip(Token::Separator(':')) {
+ Some(self.type_decl(lexer, &mut ctx)?)
+ } else {
+ None
+ };
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ Some(self.general_expression(lexer, &mut ctx)?)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Separator(';'))?;
+
+ Some(ast::GlobalDeclKind::Override(ast::Override {
+ name,
+ id: id.value,
+ ty,
+ init,
+ }))
+ }
(Token::Word("var"), _) => {
let mut var = self.variable_decl(lexer, &mut ctx)?;
var.binding = binding.take();
@@ -2347,4 +2387,30 @@ impl Parser {
Ok(tu)
}
+
+ const fn increase_brace_nesting(
+ brace_nesting_level: u8,
+ brace_span: Span,
+ ) -> Result<u8, Error<'static>> {
+ // From [spec.](https://gpuweb.github.io/gpuweb/wgsl/#limits):
+ //
+ // > § 2.4. Limits
+ // >
+ // > …
+ // >
+ // > Maximum nesting depth of brace-enclosed statements in a function[:] 127
+ //
+ // _However_, we choose 64 instead because (a) it avoids stack overflows in CI and
+ // (b) we expect the limit to be decreased to 63 based on this conversation in
+ // WebGPU CTS upstream:
+ // <https://github.com/gpuweb/cts/pull/3389#discussion_r1543742701>
+ const BRACE_NESTING_MAXIMUM: u8 = 64;
+ if brace_nesting_level + 1 > BRACE_NESTING_MAXIMUM {
+ return Err(Error::ExceededLimitForNestedBraces {
+ span: brace_span,
+ limit: BRACE_NESTING_MAXIMUM,
+ });
+ }
+ Ok(brace_nesting_level + 1)
+ }
}
diff --git a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
index c8331ace09..63bc9f7317 100644
--- a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
+++ b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
@@ -226,7 +226,8 @@ mod tests {
let gctx = crate::proc::GlobalCtx {
types: &types,
constants: &crate::Arena::new(),
- const_expressions: &crate::Arena::new(),
+ overrides: &crate::Arena::new(),
+ global_expressions: &crate::Arena::new(),
};
let array = crate::TypeInner::Array {
base: mytype1,