summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/front/wgsl/lower/mod.rs
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 05:43:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 05:43:14 +0000
commit8dd16259287f58f9273002717ec4d27e97127719 (patch)
tree3863e62a53829a84037444beab3abd4ed9dfc7d0 /third_party/rust/naga/src/front/wgsl/lower/mod.rs
parentReleasing progress-linux version 126.0.1-1~progress7.99u1. (diff)
downloadfirefox-8dd16259287f58f9273002717ec4d27e97127719.tar.xz
firefox-8dd16259287f58f9273002717ec4d27e97127719.zip
Merging upstream version 127.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/front/wgsl/lower/mod.rs')
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/mod.rs414
1 files changed, 307 insertions, 107 deletions
diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
index 2ca6c182b7..e7cce17723 100644
--- a/third_party/rust/naga/src/front/wgsl/lower/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
@@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> {
module: &'out mut crate::Module,
const_typifier: &'temp mut Typifier,
+
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
impl<'source> GlobalContext<'source, '_, '_> {
@@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> {
module: self.module,
const_typifier: self.const_typifier,
expr_type: ExpressionContextType::Constant,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
+ }
+ }
+
+ fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> {
+ ExpressionContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ expr_type: ExpressionContextType::Override,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -164,7 +179,8 @@ pub struct StatementContext<'source, 'temp, 'out> {
/// with the form of the expressions; it is also tracking whether WGSL says
/// we should consider them to be const. See the use of `force_non_const` in
/// the code for lowering `let` bindings.
- expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+ local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
@@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
ast_expressions: self.ast_expressions,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
module: self.module,
expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext {
local_table: self.local_table,
@@ -188,7 +205,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
block,
emitter,
typifier: self.typifier,
- expression_constness: self.expression_constness,
+ local_expression_kind_tracker: self.local_expression_kind_tracker,
}),
}
}
@@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -232,8 +250,8 @@ pub struct RuntimeExpressionContext<'temp, 'out> {
/// Which `Expression`s in `self.naga_expressions` are const expressions, in
/// the WGSL sense.
///
- /// See [`StatementContext::expression_constness`] for details.
- expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+ /// See [`StatementContext::local_expression_kind_tracker`] for details.
+ local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
/// The type of Naga IR expression we are lowering an [`ast::Expression`] to.
@@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> {
/// available in the [`ExpressionContext`], so this variant
/// carries no further information.
Constant,
+
+ /// We are lowering to an override expression, to be included in the module's
+ /// constant expression arena.
+ ///
+ /// Everything override expressions are allowed to refer to is
+ /// available in the [`ExpressionContext`], so this variant
+ /// carries no further information.
+ Override,
}
/// State for lowering an [`ast::Expression`] to Naga IR.
@@ -307,10 +333,11 @@ pub struct ExpressionContext<'source, 'temp, 'out> {
/// [`Module`]: crate::Module
module: &'out mut crate::Module,
- /// Type judgments for [`module::const_expressions`].
+ /// Type judgments for [`module::global_expressions`].
///
- /// [`module::const_expressions`]: crate::Module::const_expressions
+ /// [`module::global_expressions`]: crate::Module::global_expressions
const_typifier: &'temp mut Typifier,
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
/// Whether we are lowering a constant expression or a general
/// runtime expression, and the data needed in each case.
@@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
const_typifier: self.const_typifier,
module: self.module,
expr_type: ExpressionContextType::Constant,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -344,11 +373,20 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function(
self.module,
&mut rctx.function.expressions,
- rctx.expression_constness,
+ rctx.local_expression_kind_tracker,
rctx.emitter,
rctx.block,
),
- ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module),
+ ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ false,
+ ),
+ ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ true,
+ ),
}
}
@@ -358,24 +396,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
span: Span,
) -> Result<Handle<crate::Expression>, Error<'source>> {
let mut eval = self.as_const_evaluator();
- match eval.try_eval_and_append(&expr, span) {
- Ok(expr) => Ok(expr),
-
- // `expr` is not a constant expression. This is fine as
- // long as we're not building `Module::const_expressions`.
- Err(err) => match self.expr_type {
- ExpressionContextType::Runtime(ref mut rctx) => {
- Ok(rctx.function.expressions.append(expr, span))
- }
- ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)),
- },
- }
+ eval.try_eval_and_append(expr, span)
+ .map_err(|e| Error::ConstantEvaluatorError(e, span))
}
fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => {
- if !ctx.expression_constness.is_const(handle) {
+ if !ctx.local_expression_kind_tracker.is_const(handle) {
return None;
}
@@ -385,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
.ok()
}
ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
+ ExpressionContextType::Override => None,
}
}
fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle),
- ExpressionContextType::Constant => self.module.const_expressions.get_span(handle),
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ self.module.global_expressions.get_span(handle)
+ }
}
}
fn typifier(&self) -> &Typifier {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
- ExpressionContextType::Constant => self.const_typifier,
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ self.const_typifier
+ }
}
}
@@ -408,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
- ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)),
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ Err(Error::UnexpectedOperationInConstContext(span))
+ }
}
}
@@ -420,7 +455,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<crate::SwizzleComponent, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref rctx) => {
- if !rctx.expression_constness.is_const(expr) {
+ if !rctx.local_expression_kind_tracker.is_const(expr) {
return Err(Error::ExpectedConstExprConcreteIntegerScalar(
component_span,
));
@@ -445,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
// This means a `gather` operation appeared in a constant expression.
// This error refers to the `gather` itself, not its "component" argument.
- ExpressionContextType::Constant => {
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(gather_span))
}
}
@@ -471,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
// to also borrow self.module.types mutably below.
let typifier = match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
- ExpressionContextType::Constant => &*self.const_typifier,
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ &*self.const_typifier
+ }
};
Ok(typifier.register_type(handle, &mut self.module.types))
}
@@ -514,10 +551,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
typifier = &mut *ctx.typifier;
expressions = &ctx.function.expressions;
}
- ExpressionContextType::Constant => {
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]);
typifier = self.const_typifier;
- expressions = &self.module.const_expressions;
+ expressions = &self.module.global_expressions;
}
};
typifier
@@ -610,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
}
- ExpressionContextType::Constant => {}
+ ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
let result = self.append_expression(expression, span);
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
rctx.emitter.start(&rctx.function.expressions);
}
- ExpressionContextType::Constant => {}
+ ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
result
}
@@ -786,6 +823,7 @@ enum LoweredGlobalDecl {
Function(Handle<crate::Function>),
Var(Handle<crate::GlobalVariable>),
Const(Handle<crate::Constant>),
+ Override(Handle<crate::Override>),
Type(Handle<crate::Type>),
EntryPoint,
}
@@ -836,6 +874,29 @@ impl Texture {
}
}
+enum SubgroupGather {
+ BroadcastFirst,
+ Broadcast,
+ Shuffle,
+ ShuffleDown,
+ ShuffleUp,
+ ShuffleXor,
+}
+
+impl SubgroupGather {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "subgroupBroadcastFirst" => Self::BroadcastFirst,
+ "subgroupBroadcast" => Self::Broadcast,
+ "subgroupShuffle" => Self::Shuffle,
+ "subgroupShuffleDown" => Self::ShuffleDown,
+ "subgroupShuffleUp" => Self::ShuffleUp,
+ "subgroupShuffleXor" => Self::ShuffleXor,
+ _ => return None,
+ })
+ }
+}
+
pub struct Lowerer<'source, 'temp> {
index: &'temp Index<'source>,
layouter: Layouter,
@@ -861,6 +922,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
types: &tu.types,
module: &mut module,
const_typifier: &mut Typifier::new(),
+ global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker::new(),
};
for decl_handle in self.index.visit_ordered() {
@@ -877,7 +939,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let init;
if let Some(init_ast) = v.init {
- let mut ectx = ctx.as_const();
+ let mut ectx = ctx.as_override();
let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(ty);
let converted = ectx
@@ -956,7 +1018,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let handle = ctx.module.constants.append(
crate::Constant {
name: Some(c.name.name.to_string()),
- r#override: crate::Override::None,
ty,
init,
},
@@ -966,6 +1027,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ctx.globals
.insert(c.name.name, LoweredGlobalDecl::Const(handle));
}
+ ast::GlobalDeclKind::Override(ref o) => {
+ let init = o
+ .init
+ .map(|init| self.expression(init, &mut ctx.as_override()))
+ .transpose()?;
+ let inferred_type = init
+ .map(|init| ctx.as_const().register_type(init))
+ .transpose()?;
+
+ let explicit_ty =
+ o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx))
+ .transpose()?;
+
+ let id =
+ o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
+ .transpose()?;
+
+ let id = if let Some((id, id_span)) = id {
+ Some(
+ u16::try_from(id)
+ .map_err(|_| Error::PipelineConstantIDValue(id_span))?,
+ )
+ } else {
+ None
+ };
+
+ let ty = match (explicit_ty, inferred_type) {
+ (Some(explicit_ty), Some(inferred_type)) => {
+ if explicit_ty == inferred_type {
+ explicit_ty
+ } else {
+ let gctx = ctx.module.to_ctx();
+ return Err(Error::InitializationTypeMismatch {
+ name: o.name.span,
+ expected: explicit_ty.to_wgsl(&gctx),
+ got: inferred_type.to_wgsl(&gctx),
+ });
+ }
+ }
+ (Some(explicit_ty), None) => explicit_ty,
+ (None, Some(inferred_type)) => inferred_type,
+ (None, None) => {
+ return Err(Error::DeclMissingTypeAndInit(o.name.span));
+ }
+ };
+
+ let handle = ctx.module.overrides.append(
+ crate::Override {
+ name: Some(o.name.name.to_string()),
+ id,
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(o.name.name, LoweredGlobalDecl::Override(handle));
+ }
ast::GlobalDeclKind::Struct(ref s) => {
let handle = self.r#struct(s, span, &mut ctx)?;
ctx.globals
@@ -1000,6 +1120,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut local_table = FastHashMap::default();
let mut expressions = Arena::new();
let mut named_expressions = FastIndexMap::default();
+ let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
let arguments = f
.arguments
@@ -1011,6 +1132,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
local_table.insert(arg.handle, Typed::Plain(expr));
named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
+ local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime);
Ok(crate::FunctionArgument {
name: Some(arg.name.name.to_string()),
@@ -1053,7 +1175,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
named_expressions: &mut named_expressions,
types: ctx.types,
module: ctx.module,
- expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(),
+ local_expression_kind_tracker: &mut local_expression_kind_tracker,
+ global_expression_kind_tracker: ctx.global_expression_kind_tracker,
};
let mut body = self.block(&f.body, false, &mut stmt_ctx)?;
ensure_block_returns(&mut body);
@@ -1132,7 +1255,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// affects when errors must be reported, so we can't even
// treat suitable `let` bindings as constant as an
// optimization.
- ctx.expression_constness.force_non_const(value);
+ ctx.local_expression_kind_tracker.force_non_const(value);
let explicit_ty =
l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
@@ -1203,7 +1326,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty = explicit_ty;
initializer = None;
}
- (None, None) => return Err(Error::MissingType(v.name.span)),
+ (None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)),
}
let (const_initializer, initializer) = {
@@ -1216,7 +1339,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// - the initialization is not a constant
// expression, so its value depends on the
// state at the point of initialization.
- if is_inside_loop || !ctx.expression_constness.is_const(init) {
+ if is_inside_loop
+ || !ctx.local_expression_kind_tracker.is_const_or_override(init)
+ {
(None, Some(init))
} else {
(Some(init), None)
@@ -1469,6 +1594,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.function
.expressions
.append(crate::Expression::Binary { op, left, right }, stmt.span);
+ rctx.local_expression_kind_tracker
+ .insert(left, crate::proc::ExpressionKind::Runtime);
+ rctx.local_expression_kind_tracker
+ .insert(value, crate::proc::ExpressionKind::Runtime);
block.extend(emitter.finish(&ctx.function.expressions));
crate::Statement::Store {
@@ -1562,7 +1691,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
LoweredGlobalDecl::Const(handle) => {
Typed::Plain(crate::Expression::Constant(handle))
}
- _ => {
+ LoweredGlobalDecl::Override(handle) => {
+ Typed::Plain(crate::Expression::Override(handle))
+ }
+ LoweredGlobalDecl::Function(_)
+ | LoweredGlobalDecl::Type(_)
+ | LoweredGlobalDecl::EntryPoint => {
return Err(Error::Unexpected(span, ExpectedToken::Variable));
}
};
@@ -1819,9 +1953,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
Ok(Some(handle))
}
- Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
- Err(Error::Unexpected(function.span, ExpectedToken::Function))
- }
+ Some(
+ &LoweredGlobalDecl::Const(_)
+ | &LoweredGlobalDecl::Override(_)
+ | &LoweredGlobalDecl::Var(_),
+ ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)),
Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
Some(&LoweredGlobalDecl::Function(function)) => {
let arguments = arguments
@@ -1835,9 +1971,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
let result = has_result.then(|| {
- rctx.function
+ let result = rctx
+ .function
.expressions
- .append(crate::Expression::CallResult(function), span)
+ .append(crate::Expression::CallResult(function), span);
+ rctx.local_expression_kind_tracker
+ .insert(result, crate::proc::ExpressionKind::Runtime);
+ result
});
rctx.emitter.start(&rctx.function.expressions);
rctx.block.push(
@@ -1937,6 +2077,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
} else if let Some(fun) = Texture::map(function.name) {
self.texture_sample_helper(fun, arguments, span, ctx)?
+ } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
+ return Ok(Some(
+ self.subgroup_operation_helper(span, op, cop, arguments, ctx)?,
+ ));
+ } else if let Some(mode) = SubgroupGather::map(function.name) {
+ return Ok(Some(
+ self.subgroup_gather_helper(span, mode, arguments, ctx)?,
+ ));
+ } else if let Some(fun) = crate::AtomicFunction::map(function.name) {
+ return Ok(Some(self.atomic_helper(span, fun, arguments, ctx)?));
} else {
match function.name {
"select" => {
@@ -1982,70 +2132,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Store { pointer, value }, span);
return Ok(None);
}
- "atomicAdd" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Add,
- arguments,
- ctx,
- )?))
- }
- "atomicSub" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Subtract,
- arguments,
- ctx,
- )?))
- }
- "atomicAnd" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::And,
- arguments,
- ctx,
- )?))
- }
- "atomicOr" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::InclusiveOr,
- arguments,
- ctx,
- )?))
- }
- "atomicXor" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::ExclusiveOr,
- arguments,
- ctx,
- )?))
- }
- "atomicMin" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Min,
- arguments,
- ctx,
- )?))
- }
- "atomicMax" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Max,
- arguments,
- ctx,
- )?))
- }
- "atomicExchange" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Exchange { compare: None },
- arguments,
- ctx,
- )?))
- }
"atomicCompareExchangeWeak" => {
let mut args = ctx.prepare_args(arguments, 3, span);
@@ -2104,6 +2190,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
return Ok(None);
}
+ "subgroupBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span);
+ return Ok(None);
+ }
"workgroupUniformLoad" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let expr = args.next()?;
@@ -2311,6 +2405,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
return Ok(Some(handle));
}
+ "subgroupBallot" => {
+ let mut args = ctx.prepare_args(arguments, 0, span);
+ let predicate = if arguments.len() == 1 {
+ Some(self.expression(args.next()?, ctx)?)
+ } else {
+ None
+ };
+ args.finish()?;
+
+ let result = ctx
+ .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::SubgroupBallot { result, predicate }, span);
+ return Ok(Some(result));
+ }
_ => return Err(Error::UnknownIdent(function.span, function.name)),
}
};
@@ -2502,6 +2612,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
})
}
+ fn subgroup_operation_helper(
+ &mut self,
+ span: Span,
+ op: crate::SubgroupOperation,
+ collective_op: crate::CollectiveOperation,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+
+ let argument = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ let ty = ctx.register_type(argument)?;
+
+ let result =
+ ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
+ fn subgroup_gather_helper(
+ &mut self,
+ span: Span,
+ mode: SubgroupGather,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+
+ let argument = self.expression(args.next()?, ctx)?;
+
+ use SubgroupGather as Sg;
+ let mode = if let Sg::BroadcastFirst = mode {
+ crate::GatherMode::BroadcastFirst
+ } else {
+ let index = self.expression(args.next()?, ctx)?;
+ match mode {
+ Sg::Broadcast => crate::GatherMode::Broadcast(index),
+ Sg::Shuffle => crate::GatherMode::Shuffle(index),
+ Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index),
+ Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index),
+ Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index),
+ Sg::BroadcastFirst => unreachable!(),
+ }
+ };
+
+ args.finish()?;
+
+ let ty = ctx.register_type(argument)?;
+
+ let result =
+ ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
fn r#struct(
&mut self,
s: &ast::Struct<'source>,
@@ -2760,3 +2944,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
}
}
+
+impl crate::AtomicFunction {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "atomicAdd" => crate::AtomicFunction::Add,
+ "atomicSub" => crate::AtomicFunction::Subtract,
+ "atomicAnd" => crate::AtomicFunction::And,
+ "atomicOr" => crate::AtomicFunction::InclusiveOr,
+ "atomicXor" => crate::AtomicFunction::ExclusiveOr,
+ "atomicMin" => crate::AtomicFunction::Min,
+ "atomicMax" => crate::AtomicFunction::Max,
+ "atomicExchange" => crate::AtomicFunction::Exchange { compare: None },
+ _ => return None,
+ })
+ }
+}