summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/front
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/front')
-rw-r--r--third_party/rust/naga/src/front/glsl/context.rs48
-rw-r--r--third_party/rust/naga/src/front/glsl/error.rs18
-rw-r--r--third_party/rust/naga/src/front/glsl/functions.rs10
-rw-r--r--third_party/rust/naga/src/front/glsl/mod.rs4
-rw-r--r--third_party/rust/naga/src/front/glsl/parser.rs17
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/declarations.rs9
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/functions.rs9
-rw-r--r--third_party/rust/naga/src/front/glsl/parser_tests.rs22
-rw-r--r--third_party/rust/naga/src/front/glsl/types.rs17
-rw-r--r--third_party/rust/naga/src/front/glsl/variables.rs1
-rw-r--r--third_party/rust/naga/src/front/spv/convert.rs5
-rw-r--r--third_party/rust/naga/src/front/spv/error.rs10
-rw-r--r--third_party/rust/naga/src/front/spv/function.rs13
-rw-r--r--third_party/rust/naga/src/front/spv/image.rs13
-rw-r--r--third_party/rust/naga/src/front/spv/mod.rs455
-rw-r--r--third_party/rust/naga/src/front/spv/null.rs8
-rw-r--r--third_party/rust/naga/src/front/wgsl/error.rs29
-rw-r--r--third_party/rust/naga/src/front/wgsl/index.rs1
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/mod.rs414
-rw-r--r--third_party/rust/naga/src/front/wgsl/mod.rs11
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/ast.rs9
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/conv.rs28
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/mod.rs106
-rw-r--r--third_party/rust/naga/src/front/wgsl/to_wgsl.rs3
24 files changed, 944 insertions, 316 deletions
diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs
index f26c57965d..6ba7df593a 100644
--- a/third_party/rust/naga/src/front/glsl/context.rs
+++ b/third_party/rust/naga/src/front/glsl/context.rs
@@ -77,12 +77,19 @@ pub struct Context<'a> {
pub body: Block,
pub module: &'a mut crate::Module,
pub is_const: bool,
- /// Tracks the constness of `Expression`s residing in `self.expressions`
- pub expression_constness: crate::proc::ExpressionConstnessTracker,
+ /// Tracks the expression kind of `Expression`s residing in `self.expressions`
+ pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
+ /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions`
+ pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}
impl<'a> Context<'a> {
- pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result<Self> {
+ pub fn new(
+ frontend: &Frontend,
+ module: &'a mut crate::Module,
+ is_const: bool,
+ global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
+ ) -> Result<Self> {
let mut this = Context {
expressions: Arena::new(),
locals: Arena::new(),
@@ -101,7 +108,8 @@ impl<'a> Context<'a> {
body: Block::new(),
module,
is_const: false,
- expression_constness: crate::proc::ExpressionConstnessTracker::new(),
+ local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
+ global_expression_kind_tracker,
};
this.emit_start();
@@ -249,40 +257,24 @@ impl<'a> Context<'a> {
pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
let mut eval = if self.is_const {
- crate::proc::ConstantEvaluator::for_glsl_module(self.module)
+ crate::proc::ConstantEvaluator::for_glsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ )
} else {
crate::proc::ConstantEvaluator::for_glsl_function(
self.module,
&mut self.expressions,
- &mut self.expression_constness,
+ &mut self.local_expression_kind_tracker,
&mut self.emitter,
&mut self.body,
)
};
- let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error {
+ eval.try_eval_and_append(expr, meta).map_err(|e| Error {
kind: e.into(),
meta,
- });
-
- match res {
- Ok(expr) => Ok(expr),
- Err(e) => {
- if self.is_const {
- Err(e)
- } else {
- let needs_pre_emit = expr.needs_pre_emit();
- if needs_pre_emit {
- self.body.extend(self.emitter.finish(&self.expressions));
- }
- let h = self.expressions.append(expr, meta);
- if needs_pre_emit {
- self.emitter.start(&self.expressions);
- }
- Ok(h)
- }
- }
- }
+ })
}
/// Add variable to current scope
@@ -1479,7 +1471,7 @@ impl Index<Handle<Expression>> for Context<'_> {
fn index(&self, index: Handle<Expression>) -> &Self::Output {
if self.is_const {
- &self.module.const_expressions[index]
+ &self.module.global_expressions[index]
} else {
&self.expressions[index]
}
diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs
index bd16ee30bc..e0771437e6 100644
--- a/third_party/rust/naga/src/front/glsl/error.rs
+++ b/third_party/rust/naga/src/front/glsl/error.rs
@@ -1,4 +1,5 @@
use super::token::TokenValue;
+use crate::SourceLocation;
use crate::{proc::ConstantEvaluatorError, Span};
use codespan_reporting::diagnostic::{Diagnostic, Label};
use codespan_reporting::files::SimpleFile;
@@ -137,14 +138,21 @@ pub struct Error {
pub meta: Span,
}
+impl Error {
+ /// Returns a [`SourceLocation`] for the error message.
+ pub fn location(&self, source: &str) -> Option<SourceLocation> {
+ Some(self.meta.location(source))
+ }
+}
+
/// A collection of errors returned during shader parsing.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(PartialEq))]
-pub struct ParseError {
+pub struct ParseErrors {
pub errors: Vec<Error>,
}
-impl ParseError {
+impl ParseErrors {
pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) {
self.emit_to_writer_with_path(writer, source, "glsl");
}
@@ -172,19 +180,19 @@ impl ParseError {
}
}
-impl std::fmt::Display for ParseError {
+impl std::fmt::Display for ParseErrors {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.errors.iter().try_for_each(|e| write!(f, "{e:?}"))
}
}
-impl std::error::Error for ParseError {
+impl std::error::Error for ParseErrors {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
None
}
}
-impl From<Vec<Error>> for ParseError {
+impl From<Vec<Error>> for ParseErrors {
fn from(errors: Vec<Error>) -> Self {
Self { errors }
}
diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs
index 01846eb814..fa1bbef56b 100644
--- a/third_party/rust/naga/src/front/glsl/functions.rs
+++ b/third_party/rust/naga/src/front/glsl/functions.rs
@@ -1236,6 +1236,8 @@ impl Frontend {
let pointer = ctx
.expressions
.append(Expression::GlobalVariable(arg.handle), Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(pointer, crate::proc::ExpressionKind::Runtime);
let ty = ctx.module.global_variables[arg.handle].ty;
@@ -1256,6 +1258,8 @@ impl Frontend {
let value = ctx
.expressions
.append(Expression::FunctionArgument(idx), Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(value, crate::proc::ExpressionKind::Runtime);
ctx.body
.push(Statement::Store { pointer, value }, Default::default());
},
@@ -1285,6 +1289,8 @@ impl Frontend {
let pointer = ctx
.expressions
.append(Expression::GlobalVariable(arg.handle), Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(pointer, crate::proc::ExpressionKind::Runtime);
let ty = ctx.module.global_variables[arg.handle].ty;
@@ -1307,6 +1313,8 @@ impl Frontend {
let load = ctx
.expressions
.append(Expression::Load { pointer }, Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(load, crate::proc::ExpressionKind::Runtime);
ctx.body.push(
Statement::Emit(ctx.expressions.range_from(len)),
Default::default(),
@@ -1329,6 +1337,8 @@ impl Frontend {
let res = ctx
.expressions
.append(Expression::Compose { ty, components }, Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(res, crate::proc::ExpressionKind::Runtime);
ctx.body.push(
Statement::Emit(ctx.expressions.range_from(len)),
Default::default(),
diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs
index 75f3929db4..ea202b2445 100644
--- a/third_party/rust/naga/src/front/glsl/mod.rs
+++ b/third_party/rust/naga/src/front/glsl/mod.rs
@@ -13,7 +13,7 @@ To begin, take a look at the documentation for the [`Frontend`].
*/
pub use ast::{Precision, Profile};
-pub use error::{Error, ErrorKind, ExpectedToken, ParseError};
+pub use error::{Error, ErrorKind, ExpectedToken, ParseErrors};
pub use token::TokenValue;
use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type};
@@ -196,7 +196,7 @@ impl Frontend {
&mut self,
options: &Options,
source: &str,
- ) -> std::result::Result<Module, ParseError> {
+ ) -> std::result::Result<Module, ParseErrors> {
self.reset(options.stage);
let lexer = lex::Lexer::new(source, &options.defines);
diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs
index 851d2e1d79..28e0808063 100644
--- a/third_party/rust/naga/src/front/glsl/parser.rs
+++ b/third_party/rust/naga/src/front/glsl/parser.rs
@@ -164,9 +164,15 @@ impl<'source> ParsingContext<'source> {
pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> {
let mut module = Module::default();
+ let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
// Body and expression arena for global initialization
- let mut ctx = Context::new(frontend, &mut module, false)?;
+ let mut ctx = Context::new(
+ frontend,
+ &mut module,
+ false,
+ &mut global_expression_kind_tracker,
+ )?;
while self.peek(frontend).is_some() {
self.parse_external_declaration(frontend, &mut ctx)?;
@@ -196,7 +202,11 @@ impl<'source> ParsingContext<'source> {
frontend: &mut Frontend,
ctx: &mut Context,
) -> Result<(u32, Span)> {
- let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?;
+ let (const_expr, meta) = self.parse_constant_expression(
+ frontend,
+ ctx.module,
+ ctx.global_expression_kind_tracker,
+ )?;
let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr);
@@ -219,8 +229,9 @@ impl<'source> ParsingContext<'source> {
&mut self,
frontend: &mut Frontend,
module: &mut Module,
+ global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
) -> Result<(Handle<Expression>, Span)> {
- let mut ctx = Context::new(frontend, module, true)?;
+ let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?;
let mut stmt_ctx = ctx.stmt_ctx();
let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?;
diff --git a/third_party/rust/naga/src/front/glsl/parser/declarations.rs b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
index f5e38fb016..2d253a378d 100644
--- a/third_party/rust/naga/src/front/glsl/parser/declarations.rs
+++ b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
@@ -251,7 +251,7 @@ impl<'source> ParsingContext<'source> {
init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok());
late_initializer = None;
} else if let Some(init) = init {
- if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) {
+ if ctx.is_inside_loop || !ctx.ctx.local_expression_kind_tracker.is_const(init) {
decl_initializer = None;
late_initializer = Some(init);
} else {
@@ -326,7 +326,12 @@ impl<'source> ParsingContext<'source> {
let result = ty.map(|ty| FunctionResult { ty, binding: None });
- let mut context = Context::new(frontend, ctx.module, false)?;
+ let mut context = Context::new(
+ frontend,
+ ctx.module,
+ false,
+ ctx.global_expression_kind_tracker,
+ )?;
self.parse_function_args(frontend, &mut context)?;
diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs
index d428d74761..d0c889e4d3 100644
--- a/third_party/rust/naga/src/front/glsl/parser/functions.rs
+++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs
@@ -192,10 +192,13 @@ impl<'source> ParsingContext<'source> {
TokenValue::Case => {
self.bump(frontend)?;
- let (const_expr, meta) =
- self.parse_constant_expression(frontend, ctx.module)?;
+ let (const_expr, meta) = self.parse_constant_expression(
+ frontend,
+ ctx.module,
+ ctx.global_expression_kind_tracker,
+ )?;
- match ctx.module.const_expressions[const_expr] {
+ match ctx.module.global_expressions[const_expr] {
Expression::Literal(Literal::I32(value)) => match uint {
// This unchecked cast isn't good, but since
// we only reach this code when the selector
diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs
index 259052cd27..135765ca58 100644
--- a/third_party/rust/naga/src/front/glsl/parser_tests.rs
+++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs
@@ -1,7 +1,7 @@
use super::{
ast::Profile,
error::ExpectedToken,
- error::{Error, ErrorKind, ParseError},
+ error::{Error, ErrorKind, ParseErrors},
token::TokenValue,
Frontend, Options, Span,
};
@@ -21,7 +21,7 @@ fn version() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::InvalidVersion(99000),
meta: Span::new(9, 14)
@@ -37,7 +37,7 @@ fn version() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::InvalidVersion(449),
meta: Span::new(9, 12)
@@ -53,7 +53,7 @@ fn version() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::InvalidProfile("smart".into()),
meta: Span::new(13, 18),
@@ -69,7 +69,7 @@ fn version() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![
Error {
kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,),
@@ -455,7 +455,7 @@ fn functions() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::SemanticError("Function already defined".into()),
meta: Span::new(134, 152),
@@ -539,7 +539,7 @@ fn constants() {
let mut types = module.types.iter();
let mut constants = module.constants.iter();
- let mut const_expressions = module.const_expressions.iter();
+ let mut global_expressions = module.global_expressions.iter();
let (ty_handle, ty) = types.next().unwrap();
assert_eq!(
@@ -550,14 +550,13 @@ fn constants() {
}
);
- let (init_handle, init) = const_expressions.next().unwrap();
+ let (init_handle, init) = global_expressions.next().unwrap();
assert_eq!(init, &Expression::Literal(crate::Literal::F32(1.0)));
assert_eq!(
constants.next().unwrap().1,
&Constant {
name: Some("a".to_owned()),
- r#override: crate::Override::None,
ty: ty_handle,
init: init_handle
}
@@ -567,7 +566,6 @@ fn constants() {
constants.next().unwrap().1,
&Constant {
name: Some("b".to_owned()),
- r#override: crate::Override::None,
ty: ty_handle,
init: init_handle
}
@@ -636,7 +634,7 @@ fn implicit_conversions() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::SemanticError("Unknown function \'test\'".into()),
meta: Span::new(156, 165),
@@ -660,7 +658,7 @@ fn implicit_conversions() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()),
meta: Span::new(158, 165),
diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs
index e87d76fffc..f6836169c0 100644
--- a/third_party/rust/naga/src/front/glsl/types.rs
+++ b/third_party/rust/naga/src/front/glsl/types.rs
@@ -233,7 +233,7 @@ impl Context<'_> {
};
let expressions = if self.is_const {
- &self.module.const_expressions
+ &self.module.global_expressions
} else {
&self.expressions
};
@@ -330,23 +330,25 @@ impl Context<'_> {
expr: Handle<Expression>,
) -> Result<Handle<Expression>> {
let meta = self.expressions.get_span(expr);
- Ok(match self.expressions[expr] {
+ let h = match self.expressions[expr] {
ref expr @ (Expression::Literal(_)
| Expression::Constant(_)
- | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta),
+ | Expression::ZeroValue(_)) => {
+ self.module.global_expressions.append(expr.clone(), meta)
+ }
Expression::Compose { ty, ref components } => {
let mut components = components.clone();
for component in &mut components {
*component = self.lift_up_const_expression(*component)?;
}
self.module
- .const_expressions
+ .global_expressions
.append(Expression::Compose { ty, components }, meta)
}
Expression::Splat { size, value } => {
let value = self.lift_up_const_expression(value)?;
self.module
- .const_expressions
+ .global_expressions
.append(Expression::Splat { size, value }, meta)
}
_ => {
@@ -355,6 +357,9 @@ impl Context<'_> {
meta,
})
}
- })
+ };
+ self.global_expression_kind_tracker
+ .insert(h, crate::proc::ExpressionKind::Const);
+ Ok(h)
}
}
diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs
index 9d2e7a0e7b..0725fbd94f 100644
--- a/third_party/rust/naga/src/front/glsl/variables.rs
+++ b/third_party/rust/naga/src/front/glsl/variables.rs
@@ -472,7 +472,6 @@ impl Frontend {
let constant = Constant {
name: name.clone(),
- r#override: crate::Override::None,
ty,
init,
};
diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs
index f0a714fbeb..a6bf0e0451 100644
--- a/third_party/rust/naga/src/front/spv/convert.rs
+++ b/third_party/rust/naga/src/front/spv/convert.rs
@@ -153,6 +153,11 @@ pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::B
Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
+ // subgroup
+ Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups,
+ Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId,
+ Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize,
+ Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId,
_ => return Err(Error::UnsupportedBuiltIn(word)),
})
}
diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs
index af025636c0..44beadce98 100644
--- a/third_party/rust/naga/src/front/spv/error.rs
+++ b/third_party/rust/naga/src/front/spv/error.rs
@@ -5,7 +5,7 @@ use codespan_reporting::files::SimpleFile;
use codespan_reporting::term;
use termcolor::{NoColor, WriteColor};
-#[derive(Debug, thiserror::Error)]
+#[derive(Clone, Debug, thiserror::Error)]
pub enum Error {
#[error("invalid header")]
InvalidHeader,
@@ -58,6 +58,8 @@ pub enum Error {
UnknownBinaryOperator(spirv::Op),
#[error("unknown relational function {0:?}")]
UnknownRelationalFunction(spirv::Op),
+ #[error("unsupported group operation %{0}")]
+ UnsupportedGroupOperation(spirv::Word),
#[error("invalid parameter {0:?}")]
InvalidParameter(spirv::Op),
#[error("invalid operand count {1} for {0:?}")]
@@ -118,8 +120,8 @@ pub enum Error {
ControlFlowGraphCycle(crate::front::spv::BlockId),
#[error("recursive function call %{0}")]
FunctionCallCycle(spirv::Word),
- #[error("invalid array size {0:?}")]
- InvalidArraySize(Handle<crate::Constant>),
+ #[error("invalid array size %{0}")]
+ InvalidArraySize(spirv::Word),
#[error("invalid barrier scope %{0}")]
InvalidBarrierScope(spirv::Word),
#[error("invalid barrier memory semantics %{0}")]
@@ -130,6 +132,8 @@ pub enum Error {
come from a binding)"
)]
NonBindingArrayOfImageOrSamplers,
+ #[error("naga only supports specialization constant IDs up to 65535 but was given {0}")]
+ SpecIdTooHigh(u32),
}
impl Error {
diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs
index e81ecf5c9b..113ca56313 100644
--- a/third_party/rust/naga/src/front/spv/function.rs
+++ b/third_party/rust/naga/src/front/spv/function.rs
@@ -59,8 +59,11 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
})
},
local_variables: Arena::new(),
- expressions: self
- .make_expression_storage(&module.global_variables, &module.constants),
+ expressions: self.make_expression_storage(
+ &module.global_variables,
+ &module.constants,
+ &module.overrides,
+ ),
named_expressions: crate::NamedExpressions::default(),
body: crate::Block::new(),
}
@@ -128,7 +131,8 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
expressions: &mut fun.expressions,
local_arena: &mut fun.local_variables,
const_arena: &mut module.constants,
- const_expressions: &mut module.const_expressions,
+ overrides: &mut module.overrides,
+ global_expressions: &mut module.global_expressions,
type_arena: &module.types,
global_arena: &module.global_variables,
arguments: &fun.arguments,
@@ -581,7 +585,8 @@ impl<'function> BlockContext<'function> {
crate::proc::GlobalCtx {
types: self.type_arena,
constants: self.const_arena,
- const_expressions: self.const_expressions,
+ overrides: self.overrides,
+ global_expressions: self.global_expressions,
}
}
diff --git a/third_party/rust/naga/src/front/spv/image.rs b/third_party/rust/naga/src/front/spv/image.rs
index 0f25dd626b..284c4cf7fd 100644
--- a/third_party/rust/naga/src/front/spv/image.rs
+++ b/third_party/rust/naga/src/front/spv/image.rs
@@ -507,11 +507,14 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
}
spirv::ImageOperands::CONST_OFFSET => {
let offset_constant = self.next()?;
- let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle;
- let offset_handle = ctx.const_expressions.append(
- crate::Expression::Constant(offset_handle),
- Default::default(),
- );
+ let offset_expr = self
+ .lookup_constant
+ .lookup(offset_constant)?
+ .inner
+ .to_expr();
+ let offset_handle = ctx
+ .global_expressions
+ .append(offset_expr, Default::default());
offset = Some(offset_handle);
words_left -= 1;
}
diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs
index b793448597..7ac5a18cd6 100644
--- a/third_party/rust/naga/src/front/spv/mod.rs
+++ b/third_party/rust/naga/src/front/spv/mod.rs
@@ -196,7 +196,7 @@ struct Decoration {
location: Option<spirv::Word>,
desc_set: Option<spirv::Word>,
desc_index: Option<spirv::Word>,
- specialization: Option<spirv::Word>,
+ specialization_constant_id: Option<spirv::Word>,
storage_buffer: bool,
offset: Option<spirv::Word>,
array_stride: Option<NonZeroU32>,
@@ -216,11 +216,6 @@ impl Decoration {
}
}
- fn specialization(&self) -> crate::Override {
- self.specialization
- .map_or(crate::Override::None, crate::Override::ByNameOrId)
- }
-
const fn resource_binding(&self) -> Option<crate::ResourceBinding> {
match *self {
Decoration {
@@ -284,8 +279,23 @@ struct LookupType {
}
#[derive(Debug)]
+enum Constant {
+ Constant(Handle<crate::Constant>),
+ Override(Handle<crate::Override>),
+}
+
+impl Constant {
+ const fn to_expr(&self) -> crate::Expression {
+ match *self {
+ Self::Constant(c) => crate::Expression::Constant(c),
+ Self::Override(o) => crate::Expression::Override(o),
+ }
+ }
+}
+
+#[derive(Debug)]
struct LookupConstant {
- handle: Handle<crate::Constant>,
+ inner: Constant,
type_id: spirv::Word,
}
@@ -537,7 +547,8 @@ struct BlockContext<'function> {
local_arena: &'function mut Arena<crate::LocalVariable>,
/// Constants arena of the module being processed
const_arena: &'function mut Arena<crate::Constant>,
- const_expressions: &'function mut Arena<crate::Expression>,
+ overrides: &'function mut Arena<crate::Override>,
+ global_expressions: &'function mut Arena<crate::Expression>,
/// Type arena of the module being processed
type_arena: &'function UniqueArena<crate::Type>,
/// Global arena of the module being processed
@@ -757,7 +768,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
dec.matrix_major = Some(Majority::Row);
}
spirv::Decoration::SpecId => {
- dec.specialization = Some(self.next()?);
+ dec.specialization_constant_id = Some(self.next()?);
}
other => {
log::warn!("Unknown decoration {:?}", other);
@@ -1393,10 +1404,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
inst.expect(5)?;
let init_id = self.next()?;
let lconst = self.lookup_constant.lookup(init_id)?;
- Some(
- ctx.expressions
- .append(crate::Expression::Constant(lconst.handle), span),
- )
+ Some(ctx.expressions.append(lconst.inner.to_expr(), span))
} else {
None
};
@@ -3650,9 +3658,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
let semantics_const = self.lookup_constant.lookup(semantics_id)?;
- let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle)
+ let exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
- let semantics = resolve_constant(ctx.gctx(), semantics_const.handle)
+ let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner)
.ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?;
if exec_scope == spirv::Scope::Workgroup as u32 {
@@ -3692,6 +3700,254 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
},
);
}
+ Op::GroupNonUniformBallot => {
+ inst.expect(5)?;
+ block.extend(emitter.finish(ctx.expressions));
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let exec_scope_id = self.next()?;
+ let predicate_id = self.next()?;
+
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
+ .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+
+ let predicate = if self
+ .lookup_constant
+ .lookup(predicate_id)
+ .ok()
+ .filter(|predicate_const| match predicate_const.inner {
+ Constant::Constant(constant) => matches!(
+ ctx.gctx().global_expressions[ctx.gctx().constants[constant].init],
+ crate::Expression::Literal(crate::Literal::Bool(true)),
+ ),
+ Constant::Override(_) => false,
+ })
+ .is_some()
+ {
+ None
+ } else {
+ let predicate_lookup = self.lookup_expression.lookup(predicate_id)?;
+ let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup);
+ Some(predicate_handle)
+ };
+
+ let result_handle = ctx
+ .expressions
+ .append(crate::Expression::SubgroupBallotResult, span);
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: result_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ block.push(
+ crate::Statement::SubgroupBallot {
+ result: result_handle,
+ predicate,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ spirv::Op::GroupNonUniformAll
+ | spirv::Op::GroupNonUniformAny
+ | spirv::Op::GroupNonUniformIAdd
+ | spirv::Op::GroupNonUniformFAdd
+ | spirv::Op::GroupNonUniformIMul
+ | spirv::Op::GroupNonUniformFMul
+ | spirv::Op::GroupNonUniformSMax
+ | spirv::Op::GroupNonUniformUMax
+ | spirv::Op::GroupNonUniformFMax
+ | spirv::Op::GroupNonUniformSMin
+ | spirv::Op::GroupNonUniformUMin
+ | spirv::Op::GroupNonUniformFMin
+ | spirv::Op::GroupNonUniformBitwiseAnd
+ | spirv::Op::GroupNonUniformBitwiseOr
+ | spirv::Op::GroupNonUniformBitwiseXor
+ | spirv::Op::GroupNonUniformLogicalAnd
+ | spirv::Op::GroupNonUniformLogicalOr
+ | spirv::Op::GroupNonUniformLogicalXor => {
+ block.extend(emitter.finish(ctx.expressions));
+ inst.expect(
+ if matches!(
+ inst.op,
+ spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny
+ ) {
+ 5
+ } else {
+ 6
+ },
+ )?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let exec_scope_id = self.next()?;
+ let collective_op_id = match inst.op {
+ spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => {
+ crate::CollectiveOperation::Reduce
+ }
+ _ => {
+ let group_op_id = self.next()?;
+ match spirv::GroupOperation::from_u32(group_op_id) {
+ Some(spirv::GroupOperation::Reduce) => {
+ crate::CollectiveOperation::Reduce
+ }
+ Some(spirv::GroupOperation::InclusiveScan) => {
+ crate::CollectiveOperation::InclusiveScan
+ }
+ Some(spirv::GroupOperation::ExclusiveScan) => {
+ crate::CollectiveOperation::ExclusiveScan
+ }
+ _ => return Err(Error::UnsupportedGroupOperation(group_op_id)),
+ }
+ }
+ };
+ let argument_id = self.next()?;
+
+ let argument_lookup = self.lookup_expression.lookup(argument_id)?;
+ let argument_handle = get_expr_handle!(argument_id, argument_lookup);
+
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
+ .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+
+ let op_id = match inst.op {
+ spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All,
+ spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any,
+ spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => {
+ crate::SubgroupOperation::Add
+ }
+ spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => {
+ crate::SubgroupOperation::Mul
+ }
+ spirv::Op::GroupNonUniformSMax
+ | spirv::Op::GroupNonUniformUMax
+ | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max,
+ spirv::Op::GroupNonUniformSMin
+ | spirv::Op::GroupNonUniformUMin
+ | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min,
+ spirv::Op::GroupNonUniformBitwiseAnd
+ | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And,
+ spirv::Op::GroupNonUniformBitwiseOr
+ | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or,
+ spirv::Op::GroupNonUniformBitwiseXor
+ | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor,
+ _ => unreachable!(),
+ };
+
+ let result_type = self.lookup_type.lookup(result_type_id)?;
+
+ let result_handle = ctx.expressions.append(
+ crate::Expression::SubgroupOperationResult {
+ ty: result_type.handle,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: result_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ block.push(
+ crate::Statement::SubgroupCollectiveOperation {
+ result: result_handle,
+ op: op_id,
+ collective_op: collective_op_id,
+ argument: argument_handle,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ Op::GroupNonUniformBroadcastFirst
+ | Op::GroupNonUniformBroadcast
+ | Op::GroupNonUniformShuffle
+ | Op::GroupNonUniformShuffleDown
+ | Op::GroupNonUniformShuffleUp
+ | Op::GroupNonUniformShuffleXor => {
+ inst.expect(
+ if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
+ 5
+ } else {
+ 6
+ },
+ )?;
+ block.extend(emitter.finish(ctx.expressions));
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let exec_scope_id = self.next()?;
+ let argument_id = self.next()?;
+
+ let argument_lookup = self.lookup_expression.lookup(argument_id)?;
+ let argument_handle = get_expr_handle!(argument_id, argument_lookup);
+
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
+ .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+
+ let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
+ crate::GatherMode::BroadcastFirst
+ } else {
+ let index_id = self.next()?;
+ let index_lookup = self.lookup_expression.lookup(index_id)?;
+ let index_handle = get_expr_handle!(index_id, index_lookup);
+ match inst.op {
+ spirv::Op::GroupNonUniformBroadcast => {
+ crate::GatherMode::Broadcast(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffle => {
+ crate::GatherMode::Shuffle(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffleDown => {
+ crate::GatherMode::ShuffleDown(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffleUp => {
+ crate::GatherMode::ShuffleUp(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffleXor => {
+ crate::GatherMode::ShuffleXor(index_handle)
+ }
+ _ => unreachable!(),
+ }
+ };
+
+ let result_type = self.lookup_type.lookup(result_type_id)?;
+
+ let result_handle = ctx.expressions.append(
+ crate::Expression::SubgroupOperationResult {
+ ty: result_type.handle,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: result_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ block.push(
+ crate::Statement::SubgroupGather {
+ result: result_handle,
+ mode,
+ argument: argument_handle,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
_ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
}
};
@@ -3713,6 +3969,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
&mut self,
globals: &Arena<crate::GlobalVariable>,
constants: &Arena<crate::Constant>,
+ overrides: &Arena<crate::Override>,
) -> Arena<crate::Expression> {
let mut expressions = Arena::new();
#[allow(clippy::panic)]
@@ -3737,8 +3994,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
}
// register constants
for (&id, con) in self.lookup_constant.iter() {
- let span = constants.get_span(con.handle);
- let handle = expressions.append(crate::Expression::Constant(con.handle), span);
+ let (expr, span) = match con.inner {
+ Constant::Constant(c) => (crate::Expression::Constant(c), constants.get_span(c)),
+ Constant::Override(o) => (crate::Expression::Override(o), overrides.get_span(o)),
+ };
+ let handle = expressions.append(expr, span);
self.lookup_expression.insert(
id,
LookupExpression {
@@ -3812,7 +4072,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
| S::Store { .. }
| S::ImageStore { .. }
| S::Atomic { .. }
- | S::RayQuery { .. } => {}
+ | S::RayQuery { .. }
+ | S::SubgroupBallot { .. }
+ | S::SubgroupCollectiveOperation { .. }
+ | S::SubgroupGather { .. } => {}
S::Call {
function: ref mut callee,
ref arguments,
@@ -3944,10 +4207,16 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::TypeSampledImage => self.parse_type_sampled_image(inst),
Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
- Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
+ Op::ConstantComposite | Op::SpecConstantComposite => {
+ self.parse_composite_constant(inst, &mut module)
+ }
Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
- Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
- Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module),
+ Op::ConstantTrue | Op::SpecConstantTrue => {
+ self.parse_bool_constant(inst, true, &mut module)
+ }
+ Op::ConstantFalse | Op::SpecConstantFalse => {
+ self.parse_bool_constant(inst, false, &mut module)
+ }
Op::Variable => self.parse_global_variable(inst, &mut module),
Op::Function => {
self.switch(ModuleState::Function, inst.op)?;
@@ -4504,9 +4773,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let length_id = self.next()?;
let length_const = self.lookup_constant.lookup(length_id)?;
- let size = resolve_constant(module.to_ctx(), length_const.handle)
+ let size = resolve_constant(module.to_ctx(), &length_const.inner)
.and_then(NonZeroU32::new)
- .ok_or(Error::InvalidArraySize(length_const.handle))?;
+ .ok_or(Error::InvalidArraySize(length_id))?;
let decor = self.future_decor.remove(&id).unwrap_or_default();
let base = self.lookup_type.lookup(type_id)?.handle;
@@ -4919,29 +5188,13 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
_ => return Err(Error::UnsupportedType(type_lookup.handle)),
};
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
let span = self.span_from_with_op(start);
let init = module
- .const_expressions
+ .global_expressions
.append(crate::Expression::Literal(literal), span);
- self.lookup_constant.insert(
- id,
- LookupConstant {
- handle: module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- ),
- type_id,
- },
- );
- Ok(())
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
}
fn parse_composite_constant(
@@ -4965,34 +5218,18 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let span = self.span_from_with_op(start);
let constant = self.lookup_constant.lookup(component_id)?;
let expr = module
- .const_expressions
- .append(crate::Expression::Constant(constant.handle), span);
+ .global_expressions
+ .append(constant.inner.to_expr(), span);
components.push(expr);
}
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
let span = self.span_from_with_op(start);
let init = module
- .const_expressions
+ .global_expressions
.append(crate::Expression::Compose { ty, components }, span);
- self.lookup_constant.insert(
- id,
- LookupConstant {
- handle: module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- ),
- type_id,
- },
- );
- Ok(())
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
}
fn parse_null_constant(
@@ -5010,23 +5247,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
let init = module
- .const_expressions
+ .global_expressions
.append(crate::Expression::ZeroValue(ty), span);
- let handle = module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- );
- self.lookup_constant
- .insert(id, LookupConstant { handle, type_id });
- Ok(())
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
}
fn parse_bool_constant(
@@ -5045,27 +5270,44 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
- let init = module.const_expressions.append(
+ let init = module.global_expressions.append(
crate::Expression::Literal(crate::Literal::Bool(value)),
span,
);
- self.lookup_constant.insert(
- id,
- LookupConstant {
- handle: module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- ),
- type_id,
- },
- );
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
+ }
+
+ fn insert_parsed_constant(
+ &mut self,
+ module: &mut crate::Module,
+ id: u32,
+ type_id: u32,
+ ty: Handle<crate::Type>,
+ init: Handle<crate::Expression>,
+ span: crate::Span,
+ ) -> Result<(), Error> {
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ let inner = if let Some(id) = decor.specialization_constant_id {
+ let o = crate::Override {
+ name: decor.name,
+ id: Some(id.try_into().map_err(|_| Error::SpecIdTooHigh(id))?),
+ ty,
+ init: Some(init),
+ };
+ Constant::Override(module.overrides.append(o, span))
+ } else {
+ let c = crate::Constant {
+ name: decor.name,
+ ty,
+ init,
+ };
+ Constant::Constant(module.constants.append(c, span))
+ };
+
+ self.lookup_constant
+ .insert(id, LookupConstant { inner, type_id });
Ok(())
}
@@ -5087,8 +5329,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let span = self.span_from_with_op(start);
let lconst = self.lookup_constant.lookup(init_id)?;
let expr = module
- .const_expressions
- .append(crate::Expression::Constant(lconst.handle), span);
+ .global_expressions
+ .append(lconst.inner.to_expr(), span);
Some(expr)
} else {
None
@@ -5209,7 +5451,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
match null::generate_default_built_in(
Some(built_in),
ty,
- &mut module.const_expressions,
+ &mut module.global_expressions,
span,
) {
Ok(handle) => Some(handle),
@@ -5231,14 +5473,14 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let handle = null::generate_default_built_in(
built_in,
member.ty,
- &mut module.const_expressions,
+ &mut module.global_expressions,
span,
)?;
components.push(handle);
}
Some(
module
- .const_expressions
+ .global_expressions
.append(crate::Expression::Compose { ty, components }, span),
)
}
@@ -5303,11 +5545,12 @@ fn make_index_literal(
Ok(expr)
}
-fn resolve_constant(
- gctx: crate::proc::GlobalCtx,
- constant: Handle<crate::Constant>,
-) -> Option<u32> {
- match gctx.const_expressions[gctx.constants[constant].init] {
+fn resolve_constant(gctx: crate::proc::GlobalCtx, constant: &Constant) -> Option<u32> {
+ let constant = match *constant {
+ Constant::Constant(constant) => constant,
+ Constant::Override(_) => return None,
+ };
+ match gctx.global_expressions[gctx.constants[constant].init] {
crate::Expression::Literal(crate::Literal::U32(id)) => Some(id),
crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32),
_ => None,
diff --git a/third_party/rust/naga/src/front/spv/null.rs b/third_party/rust/naga/src/front/spv/null.rs
index 42cccca80a..c7d3776841 100644
--- a/third_party/rust/naga/src/front/spv/null.rs
+++ b/third_party/rust/naga/src/front/spv/null.rs
@@ -5,14 +5,14 @@ use crate::arena::{Arena, Handle};
pub fn generate_default_built_in(
built_in: Option<crate::BuiltIn>,
ty: Handle<crate::Type>,
- const_expressions: &mut Arena<crate::Expression>,
+ global_expressions: &mut Arena<crate::Expression>,
span: crate::Span,
) -> Result<Handle<crate::Expression>, Error> {
let expr = match built_in {
Some(crate::BuiltIn::Position { .. }) => {
- let zero = const_expressions
+ let zero = global_expressions
.append(crate::Expression::Literal(crate::Literal::F32(0.0)), span);
- let one = const_expressions
+ let one = global_expressions
.append(crate::Expression::Literal(crate::Literal::F32(1.0)), span);
crate::Expression::Compose {
ty,
@@ -27,5 +27,5 @@ pub fn generate_default_built_in(
// Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path
_ => crate::Expression::ZeroValue(ty),
};
- Ok(const_expressions.append(expr, span))
+ Ok(global_expressions.append(expr, span))
}
diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs
index 54aa8296b1..dc1339521c 100644
--- a/third_party/rust/naga/src/front/wgsl/error.rs
+++ b/third_party/rust/naga/src/front/wgsl/error.rs
@@ -13,6 +13,7 @@ use thiserror::Error;
#[derive(Clone, Debug)]
pub struct ParseError {
message: String,
+ // The first span should be the primary span, and the other ones should be complementary.
labels: Vec<(Span, Cow<'static, str>)>,
notes: Vec<String>,
}
@@ -190,7 +191,7 @@ pub enum Error<'a> {
expected: String,
got: String,
},
- MissingType(Span),
+ DeclMissingTypeAndInit(Span),
MissingAttribute(&'static str, Span),
InvalidAtomicPointer(Span),
InvalidAtomicOperandType(Span),
@@ -269,6 +270,11 @@ pub enum Error<'a> {
scalar: String,
inner: ConstantEvaluatorError,
},
+ ExceededLimitForNestedBraces {
+ span: Span,
+ limit: u8,
+ },
+ PipelineConstantIDValue(Span),
}
impl<'a> Error<'a> {
@@ -518,11 +524,11 @@ impl<'a> Error<'a> {
notes: vec![],
}
}
- Error::MissingType(name_span) => ParseError {
- message: format!("variable `{}` needs a type", &source[name_span]),
+ Error::DeclMissingTypeAndInit(name_span) => ParseError {
+ message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]),
labels: vec![(
name_span,
- format!("definition of `{}`", &source[name_span]).into(),
+ "needs a type specifier or initializer".into(),
)],
notes: vec![],
},
@@ -770,6 +776,21 @@ impl<'a> Error<'a> {
format!("the expression should have been converted to have {} scalar type", scalar),
]
},
+ Error::ExceededLimitForNestedBraces { span, limit } => ParseError {
+ message: "brace nesting limit reached".into(),
+ labels: vec![(span, "limit reached at this brace".into())],
+ notes: vec![
+ format!("nesting limit is currently set to {limit}"),
+ ],
+ },
+ Error::PipelineConstantIDValue(span) => ParseError {
+ message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(),
+ labels: vec![(
+ span,
+ "must be between 0 and 65535 inclusive".into(),
+ )],
+ notes: vec![],
+ },
}
}
}
diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs
index a5524fe8f1..593405508f 100644
--- a/third_party/rust/naga/src/front/wgsl/index.rs
+++ b/third_party/rust/naga/src/front/wgsl/index.rs
@@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
ast::GlobalDeclKind::Fn(ref f) => f.name,
ast::GlobalDeclKind::Var(ref v) => v.name,
ast::GlobalDeclKind::Const(ref c) => c.name,
+ ast::GlobalDeclKind::Override(ref o) => o.name,
ast::GlobalDeclKind::Struct(ref s) => s.name,
ast::GlobalDeclKind::Type(ref t) => t.name,
}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
index 2ca6c182b7..e7cce17723 100644
--- a/third_party/rust/naga/src/front/wgsl/lower/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
@@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> {
module: &'out mut crate::Module,
const_typifier: &'temp mut Typifier,
+
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
impl<'source> GlobalContext<'source, '_, '_> {
@@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> {
module: self.module,
const_typifier: self.const_typifier,
expr_type: ExpressionContextType::Constant,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
+ }
+ }
+
+ fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> {
+ ExpressionContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ expr_type: ExpressionContextType::Override,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -164,7 +179,8 @@ pub struct StatementContext<'source, 'temp, 'out> {
/// with the form of the expressions; it is also tracking whether WGSL says
/// we should consider them to be const. See the use of `force_non_const` in
/// the code for lowering `let` bindings.
- expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+ local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
@@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
ast_expressions: self.ast_expressions,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
module: self.module,
expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext {
local_table: self.local_table,
@@ -188,7 +205,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
block,
emitter,
typifier: self.typifier,
- expression_constness: self.expression_constness,
+ local_expression_kind_tracker: self.local_expression_kind_tracker,
}),
}
}
@@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -232,8 +250,8 @@ pub struct RuntimeExpressionContext<'temp, 'out> {
/// Which `Expression`s in `self.naga_expressions` are const expressions, in
/// the WGSL sense.
///
- /// See [`StatementContext::expression_constness`] for details.
- expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+ /// See [`StatementContext::local_expression_kind_tracker`] for details.
+ local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
/// The type of Naga IR expression we are lowering an [`ast::Expression`] to.
@@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> {
/// available in the [`ExpressionContext`], so this variant
/// carries no further information.
Constant,
+
+ /// We are lowering to an override expression, to be included in the module's
+ /// constant expression arena.
+ ///
+ /// Everything override expressions are allowed to refer to is
+ /// available in the [`ExpressionContext`], so this variant
+ /// carries no further information.
+ Override,
}
/// State for lowering an [`ast::Expression`] to Naga IR.
@@ -307,10 +333,11 @@ pub struct ExpressionContext<'source, 'temp, 'out> {
/// [`Module`]: crate::Module
module: &'out mut crate::Module,
- /// Type judgments for [`module::const_expressions`].
+ /// Type judgments for [`module::global_expressions`].
///
- /// [`module::const_expressions`]: crate::Module::const_expressions
+ /// [`module::global_expressions`]: crate::Module::global_expressions
const_typifier: &'temp mut Typifier,
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
/// Whether we are lowering a constant expression or a general
/// runtime expression, and the data needed in each case.
@@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
const_typifier: self.const_typifier,
module: self.module,
expr_type: ExpressionContextType::Constant,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -344,11 +373,20 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function(
self.module,
&mut rctx.function.expressions,
- rctx.expression_constness,
+ rctx.local_expression_kind_tracker,
rctx.emitter,
rctx.block,
),
- ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module),
+ ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ false,
+ ),
+ ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ true,
+ ),
}
}
@@ -358,24 +396,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
span: Span,
) -> Result<Handle<crate::Expression>, Error<'source>> {
let mut eval = self.as_const_evaluator();
- match eval.try_eval_and_append(&expr, span) {
- Ok(expr) => Ok(expr),
-
- // `expr` is not a constant expression. This is fine as
- // long as we're not building `Module::const_expressions`.
- Err(err) => match self.expr_type {
- ExpressionContextType::Runtime(ref mut rctx) => {
- Ok(rctx.function.expressions.append(expr, span))
- }
- ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)),
- },
- }
+ eval.try_eval_and_append(expr, span)
+ .map_err(|e| Error::ConstantEvaluatorError(e, span))
}
fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => {
- if !ctx.expression_constness.is_const(handle) {
+ if !ctx.local_expression_kind_tracker.is_const(handle) {
return None;
}
@@ -385,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
.ok()
}
ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
+ ExpressionContextType::Override => None,
}
}
fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle),
- ExpressionContextType::Constant => self.module.const_expressions.get_span(handle),
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ self.module.global_expressions.get_span(handle)
+ }
}
}
fn typifier(&self) -> &Typifier {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
- ExpressionContextType::Constant => self.const_typifier,
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ self.const_typifier
+ }
}
}
@@ -408,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
- ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)),
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ Err(Error::UnexpectedOperationInConstContext(span))
+ }
}
}
@@ -420,7 +455,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<crate::SwizzleComponent, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref rctx) => {
- if !rctx.expression_constness.is_const(expr) {
+ if !rctx.local_expression_kind_tracker.is_const(expr) {
return Err(Error::ExpectedConstExprConcreteIntegerScalar(
component_span,
));
@@ -445,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
// This means a `gather` operation appeared in a constant expression.
// This error refers to the `gather` itself, not its "component" argument.
- ExpressionContextType::Constant => {
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(gather_span))
}
}
@@ -471,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
// to also borrow self.module.types mutably below.
let typifier = match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
- ExpressionContextType::Constant => &*self.const_typifier,
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ &*self.const_typifier
+ }
};
Ok(typifier.register_type(handle, &mut self.module.types))
}
@@ -514,10 +551,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
typifier = &mut *ctx.typifier;
expressions = &ctx.function.expressions;
}
- ExpressionContextType::Constant => {
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]);
typifier = self.const_typifier;
- expressions = &self.module.const_expressions;
+ expressions = &self.module.global_expressions;
}
};
typifier
@@ -610,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
}
- ExpressionContextType::Constant => {}
+ ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
let result = self.append_expression(expression, span);
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
rctx.emitter.start(&rctx.function.expressions);
}
- ExpressionContextType::Constant => {}
+ ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
result
}
@@ -786,6 +823,7 @@ enum LoweredGlobalDecl {
Function(Handle<crate::Function>),
Var(Handle<crate::GlobalVariable>),
Const(Handle<crate::Constant>),
+ Override(Handle<crate::Override>),
Type(Handle<crate::Type>),
EntryPoint,
}
@@ -836,6 +874,29 @@ impl Texture {
}
}
+enum SubgroupGather {
+ BroadcastFirst,
+ Broadcast,
+ Shuffle,
+ ShuffleDown,
+ ShuffleUp,
+ ShuffleXor,
+}
+
+impl SubgroupGather {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "subgroupBroadcastFirst" => Self::BroadcastFirst,
+ "subgroupBroadcast" => Self::Broadcast,
+ "subgroupShuffle" => Self::Shuffle,
+ "subgroupShuffleDown" => Self::ShuffleDown,
+ "subgroupShuffleUp" => Self::ShuffleUp,
+ "subgroupShuffleXor" => Self::ShuffleXor,
+ _ => return None,
+ })
+ }
+}
+
pub struct Lowerer<'source, 'temp> {
index: &'temp Index<'source>,
layouter: Layouter,
@@ -861,6 +922,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
types: &tu.types,
module: &mut module,
const_typifier: &mut Typifier::new(),
+ global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker::new(),
};
for decl_handle in self.index.visit_ordered() {
@@ -877,7 +939,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let init;
if let Some(init_ast) = v.init {
- let mut ectx = ctx.as_const();
+ let mut ectx = ctx.as_override();
let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(ty);
let converted = ectx
@@ -956,7 +1018,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let handle = ctx.module.constants.append(
crate::Constant {
name: Some(c.name.name.to_string()),
- r#override: crate::Override::None,
ty,
init,
},
@@ -966,6 +1027,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ctx.globals
.insert(c.name.name, LoweredGlobalDecl::Const(handle));
}
+ ast::GlobalDeclKind::Override(ref o) => {
+ let init = o
+ .init
+ .map(|init| self.expression(init, &mut ctx.as_override()))
+ .transpose()?;
+ let inferred_type = init
+ .map(|init| ctx.as_const().register_type(init))
+ .transpose()?;
+
+ let explicit_ty =
+ o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx))
+ .transpose()?;
+
+ let id =
+ o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
+ .transpose()?;
+
+ let id = if let Some((id, id_span)) = id {
+ Some(
+ u16::try_from(id)
+ .map_err(|_| Error::PipelineConstantIDValue(id_span))?,
+ )
+ } else {
+ None
+ };
+
+ let ty = match (explicit_ty, inferred_type) {
+ (Some(explicit_ty), Some(inferred_type)) => {
+ if explicit_ty == inferred_type {
+ explicit_ty
+ } else {
+ let gctx = ctx.module.to_ctx();
+ return Err(Error::InitializationTypeMismatch {
+ name: o.name.span,
+ expected: explicit_ty.to_wgsl(&gctx),
+ got: inferred_type.to_wgsl(&gctx),
+ });
+ }
+ }
+ (Some(explicit_ty), None) => explicit_ty,
+ (None, Some(inferred_type)) => inferred_type,
+ (None, None) => {
+ return Err(Error::DeclMissingTypeAndInit(o.name.span));
+ }
+ };
+
+ let handle = ctx.module.overrides.append(
+ crate::Override {
+ name: Some(o.name.name.to_string()),
+ id,
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(o.name.name, LoweredGlobalDecl::Override(handle));
+ }
ast::GlobalDeclKind::Struct(ref s) => {
let handle = self.r#struct(s, span, &mut ctx)?;
ctx.globals
@@ -1000,6 +1120,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut local_table = FastHashMap::default();
let mut expressions = Arena::new();
let mut named_expressions = FastIndexMap::default();
+ let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
let arguments = f
.arguments
@@ -1011,6 +1132,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
local_table.insert(arg.handle, Typed::Plain(expr));
named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
+ local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime);
Ok(crate::FunctionArgument {
name: Some(arg.name.name.to_string()),
@@ -1053,7 +1175,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
named_expressions: &mut named_expressions,
types: ctx.types,
module: ctx.module,
- expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(),
+ local_expression_kind_tracker: &mut local_expression_kind_tracker,
+ global_expression_kind_tracker: ctx.global_expression_kind_tracker,
};
let mut body = self.block(&f.body, false, &mut stmt_ctx)?;
ensure_block_returns(&mut body);
@@ -1132,7 +1255,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// affects when errors must be reported, so we can't even
// treat suitable `let` bindings as constant as an
// optimization.
- ctx.expression_constness.force_non_const(value);
+ ctx.local_expression_kind_tracker.force_non_const(value);
let explicit_ty =
l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
@@ -1203,7 +1326,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty = explicit_ty;
initializer = None;
}
- (None, None) => return Err(Error::MissingType(v.name.span)),
+ (None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)),
}
let (const_initializer, initializer) = {
@@ -1216,7 +1339,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// - the initialization is not a constant
// expression, so its value depends on the
// state at the point of initialization.
- if is_inside_loop || !ctx.expression_constness.is_const(init) {
+ if is_inside_loop
+ || !ctx.local_expression_kind_tracker.is_const_or_override(init)
+ {
(None, Some(init))
} else {
(Some(init), None)
@@ -1469,6 +1594,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.function
.expressions
.append(crate::Expression::Binary { op, left, right }, stmt.span);
+ rctx.local_expression_kind_tracker
+ .insert(left, crate::proc::ExpressionKind::Runtime);
+ rctx.local_expression_kind_tracker
+ .insert(value, crate::proc::ExpressionKind::Runtime);
block.extend(emitter.finish(&ctx.function.expressions));
crate::Statement::Store {
@@ -1562,7 +1691,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
LoweredGlobalDecl::Const(handle) => {
Typed::Plain(crate::Expression::Constant(handle))
}
- _ => {
+ LoweredGlobalDecl::Override(handle) => {
+ Typed::Plain(crate::Expression::Override(handle))
+ }
+ LoweredGlobalDecl::Function(_)
+ | LoweredGlobalDecl::Type(_)
+ | LoweredGlobalDecl::EntryPoint => {
return Err(Error::Unexpected(span, ExpectedToken::Variable));
}
};
@@ -1819,9 +1953,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
Ok(Some(handle))
}
- Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
- Err(Error::Unexpected(function.span, ExpectedToken::Function))
- }
+ Some(
+ &LoweredGlobalDecl::Const(_)
+ | &LoweredGlobalDecl::Override(_)
+ | &LoweredGlobalDecl::Var(_),
+ ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)),
Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
Some(&LoweredGlobalDecl::Function(function)) => {
let arguments = arguments
@@ -1835,9 +1971,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
let result = has_result.then(|| {
- rctx.function
+ let result = rctx
+ .function
.expressions
- .append(crate::Expression::CallResult(function), span)
+ .append(crate::Expression::CallResult(function), span);
+ rctx.local_expression_kind_tracker
+ .insert(result, crate::proc::ExpressionKind::Runtime);
+ result
});
rctx.emitter.start(&rctx.function.expressions);
rctx.block.push(
@@ -1937,6 +2077,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
} else if let Some(fun) = Texture::map(function.name) {
self.texture_sample_helper(fun, arguments, span, ctx)?
+ } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
+ return Ok(Some(
+ self.subgroup_operation_helper(span, op, cop, arguments, ctx)?,
+ ));
+ } else if let Some(mode) = SubgroupGather::map(function.name) {
+ return Ok(Some(
+ self.subgroup_gather_helper(span, mode, arguments, ctx)?,
+ ));
+ } else if let Some(fun) = crate::AtomicFunction::map(function.name) {
+ return Ok(Some(self.atomic_helper(span, fun, arguments, ctx)?));
} else {
match function.name {
"select" => {
@@ -1982,70 +2132,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Store { pointer, value }, span);
return Ok(None);
}
- "atomicAdd" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Add,
- arguments,
- ctx,
- )?))
- }
- "atomicSub" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Subtract,
- arguments,
- ctx,
- )?))
- }
- "atomicAnd" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::And,
- arguments,
- ctx,
- )?))
- }
- "atomicOr" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::InclusiveOr,
- arguments,
- ctx,
- )?))
- }
- "atomicXor" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::ExclusiveOr,
- arguments,
- ctx,
- )?))
- }
- "atomicMin" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Min,
- arguments,
- ctx,
- )?))
- }
- "atomicMax" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Max,
- arguments,
- ctx,
- )?))
- }
- "atomicExchange" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Exchange { compare: None },
- arguments,
- ctx,
- )?))
- }
"atomicCompareExchangeWeak" => {
let mut args = ctx.prepare_args(arguments, 3, span);
@@ -2104,6 +2190,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
return Ok(None);
}
+ "subgroupBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span);
+ return Ok(None);
+ }
"workgroupUniformLoad" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let expr = args.next()?;
@@ -2311,6 +2405,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
return Ok(Some(handle));
}
+ "subgroupBallot" => {
+ let mut args = ctx.prepare_args(arguments, 0, span);
+ let predicate = if arguments.len() == 1 {
+ Some(self.expression(args.next()?, ctx)?)
+ } else {
+ None
+ };
+ args.finish()?;
+
+ let result = ctx
+ .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::SubgroupBallot { result, predicate }, span);
+ return Ok(Some(result));
+ }
_ => return Err(Error::UnknownIdent(function.span, function.name)),
}
};
@@ -2502,6 +2612,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
})
}
+ fn subgroup_operation_helper(
+ &mut self,
+ span: Span,
+ op: crate::SubgroupOperation,
+ collective_op: crate::CollectiveOperation,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+
+ let argument = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ let ty = ctx.register_type(argument)?;
+
+ let result =
+ ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
+ fn subgroup_gather_helper(
+ &mut self,
+ span: Span,
+ mode: SubgroupGather,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+
+ let argument = self.expression(args.next()?, ctx)?;
+
+ use SubgroupGather as Sg;
+ let mode = if let Sg::BroadcastFirst = mode {
+ crate::GatherMode::BroadcastFirst
+ } else {
+ let index = self.expression(args.next()?, ctx)?;
+ match mode {
+ Sg::Broadcast => crate::GatherMode::Broadcast(index),
+ Sg::Shuffle => crate::GatherMode::Shuffle(index),
+ Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index),
+ Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index),
+ Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index),
+ Sg::BroadcastFirst => unreachable!(),
+ }
+ };
+
+ args.finish()?;
+
+ let ty = ctx.register_type(argument)?;
+
+ let result =
+ ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
fn r#struct(
&mut self,
s: &ast::Struct<'source>,
@@ -2760,3 +2944,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
}
}
+
+impl crate::AtomicFunction {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "atomicAdd" => crate::AtomicFunction::Add,
+ "atomicSub" => crate::AtomicFunction::Subtract,
+ "atomicAnd" => crate::AtomicFunction::And,
+ "atomicOr" => crate::AtomicFunction::InclusiveOr,
+ "atomicXor" => crate::AtomicFunction::ExclusiveOr,
+ "atomicMin" => crate::AtomicFunction::Min,
+ "atomicMax" => crate::AtomicFunction::Max,
+ "atomicExchange" => crate::AtomicFunction::Exchange { compare: None },
+ _ => return None,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs
index b6151fe1c0..aec1e657fc 100644
--- a/third_party/rust/naga/src/front/wgsl/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/mod.rs
@@ -44,6 +44,17 @@ impl Frontend {
}
}
+/// <div class="warning">
+// NOTE: Keep this in sync with `wgpu::Device::create_shader_module`!
+// NOTE: Keep this in sync with `wgpu_core::Global::device_create_shader_module`!
+///
+/// This function may consume a lot of stack space. Compiler-enforced limits for parsing recursion
+/// exist; if shader compilation runs into them, it will return an error gracefully. However, on
+/// some build profiles and platforms, the default stack size for a thread may be exceeded before
+/// this limit is reached during parsing. Callers should ensure that there is enough stack space
+/// for this, particularly if calls to this method are exposed to user input.
+///
+/// </div>
pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
Frontend::new().parse(source)
}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
index dbaac523cb..ea8013ee7c 100644
--- a/third_party/rust/naga/src/front/wgsl/parse/ast.rs
+++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
@@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> {
Fn(Function<'a>),
Var(GlobalVariable<'a>),
Const(Const<'a>),
+ Override(Override<'a>),
Struct(Struct<'a>),
Type(TypeAlias<'a>),
}
@@ -200,6 +201,14 @@ pub struct Const<'a> {
pub init: Handle<Expression<'a>>,
}
+#[derive(Debug)]
+pub struct Override<'a> {
+ pub name: Ident<'a>,
+ pub id: Option<Handle<Expression<'a>>>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Option<Handle<Expression<'a>>>,
+}
+
/// The size of an [`Array`] or [`BindingArray`].
///
/// [`Array`]: Type::Array
diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
index 1a4911a3bd..207f0eda41 100644
--- a/third_party/rust/naga/src/front/wgsl/parse/conv.rs
+++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
@@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>>
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
"workgroup_id" => crate::BuiltIn::WorkGroupId,
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
+ // subgroup
+ "num_subgroups" => crate::BuiltIn::NumSubgroups,
+ "subgroup_id" => crate::BuiltIn::SubgroupId,
+ "subgroup_size" => crate::BuiltIn::SubgroupSize,
+ "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
_ => return Err(Error::UnknownBuiltin(span)),
})
}
@@ -260,3 +265,26 @@ pub fn map_conservative_depth(
_ => Err(Error::UnknownConservativeDepth(span)),
}
}
+
+pub fn map_subgroup_operation(
+ word: &str,
+) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
+ use crate::CollectiveOperation as co;
+ use crate::SubgroupOperation as sg;
+ Some(match word {
+ "subgroupAll" => (sg::All, co::Reduce),
+ "subgroupAny" => (sg::Any, co::Reduce),
+ "subgroupAdd" => (sg::Add, co::Reduce),
+ "subgroupMul" => (sg::Mul, co::Reduce),
+ "subgroupMin" => (sg::Min, co::Reduce),
+ "subgroupMax" => (sg::Max, co::Reduce),
+ "subgroupAnd" => (sg::And, co::Reduce),
+ "subgroupOr" => (sg::Or, co::Reduce),
+ "subgroupXor" => (sg::Xor, co::Reduce),
+ "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
+ "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
+ "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
+ "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
+ _ => return None,
+ })
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
index 51fc2f013b..79ea1ae609 100644
--- a/third_party/rust/naga/src/front/wgsl/parse/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
@@ -1619,22 +1619,21 @@ impl Parser {
lexer: &mut Lexer<'a>,
ctx: &mut ExpressionContext<'a, '_, '_>,
block: &mut ast::Block<'a>,
+ brace_nesting_level: u8,
) -> Result<(), Error<'a>> {
self.push_rule_span(Rule::Statement, lexer);
match lexer.peek() {
(Token::Separator(';'), _) => {
let _ = lexer.next();
self.pop_rule_span(lexer);
- return Ok(());
}
(Token::Paren('{'), _) => {
- let (inner, span) = self.block(lexer, ctx)?;
+ let (inner, span) = self.block(lexer, ctx, brace_nesting_level)?;
block.stmts.push(ast::Statement {
kind: ast::StatementKind::Block(inner),
span,
});
self.pop_rule_span(lexer);
- return Ok(());
}
(Token::Word(word), _) => {
let kind = match word {
@@ -1711,7 +1710,7 @@ impl Parser {
let _ = lexer.next();
let condition = self.general_expression(lexer, ctx)?;
- let accept = self.block(lexer, ctx)?.0;
+ let accept = self.block(lexer, ctx, brace_nesting_level)?.0;
let mut elsif_stack = Vec::new();
let mut elseif_span_start = lexer.start_byte_offset();
@@ -1722,12 +1721,12 @@ impl Parser {
if !lexer.skip(Token::Word("if")) {
// ... else { ... }
- break self.block(lexer, ctx)?.0;
+ break self.block(lexer, ctx, brace_nesting_level)?.0;
}
// ... else if (...) { ... }
let other_condition = self.general_expression(lexer, ctx)?;
- let other_block = self.block(lexer, ctx)?;
+ let other_block = self.block(lexer, ctx, brace_nesting_level)?;
elsif_stack.push((elseif_span_start, other_condition, other_block));
elseif_span_start = lexer.start_byte_offset();
};
@@ -1759,7 +1758,9 @@ impl Parser {
"switch" => {
let _ = lexer.next();
let selector = self.general_expression(lexer, ctx)?;
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level =
+ Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
let mut cases = Vec::new();
loop {
@@ -1784,7 +1785,7 @@ impl Parser {
});
};
- let body = self.block(lexer, ctx)?.0;
+ let body = self.block(lexer, ctx, brace_nesting_level)?.0;
cases.push(ast::SwitchCase {
value,
@@ -1794,7 +1795,7 @@ impl Parser {
}
(Token::Word("default"), _) => {
lexer.skip(Token::Separator(':'));
- let body = self.block(lexer, ctx)?.0;
+ let body = self.block(lexer, ctx, brace_nesting_level)?.0;
cases.push(ast::SwitchCase {
value: ast::SwitchValue::Default,
body,
@@ -1810,7 +1811,7 @@ impl Parser {
ast::StatementKind::Switch { selector, cases }
}
- "loop" => self.r#loop(lexer, ctx)?,
+ "loop" => self.r#loop(lexer, ctx, brace_nesting_level)?,
"while" => {
let _ = lexer.next();
let mut body = ast::Block::default();
@@ -1834,7 +1835,7 @@ impl Parser {
span,
});
- let (block, span) = self.block(lexer, ctx)?;
+ let (block, span) = self.block(lexer, ctx, brace_nesting_level)?;
body.stmts.push(ast::Statement {
kind: ast::StatementKind::Block(block),
span,
@@ -1857,7 +1858,9 @@ impl Parser {
let (_, span) = {
let ctx = &mut *ctx;
let block = &mut *block;
- lexer.capture_span(|lexer| self.statement(lexer, ctx, block))?
+ lexer.capture_span(|lexer| {
+ self.statement(lexer, ctx, block, brace_nesting_level)
+ })?
};
if block.stmts.len() != num_statements {
@@ -1902,7 +1905,7 @@ impl Parser {
lexer.expect(Token::Paren(')'))?;
}
- let (block, span) = self.block(lexer, ctx)?;
+ let (block, span) = self.block(lexer, ctx, brace_nesting_level)?;
body.stmts.push(ast::Statement {
kind: ast::StatementKind::Block(block),
span,
@@ -1964,13 +1967,15 @@ impl Parser {
&mut self,
lexer: &mut Lexer<'a>,
ctx: &mut ExpressionContext<'a, '_, '_>,
+ brace_nesting_level: u8,
) -> Result<ast::StatementKind<'a>, Error<'a>> {
let _ = lexer.next();
let mut body = ast::Block::default();
let mut continuing = ast::Block::default();
let mut break_if = None;
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
ctx.local_table.push_scope();
@@ -1980,7 +1985,9 @@ impl Parser {
// the last thing in the loop body
// Expect a opening brace to start the continuing block
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level =
+ Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
loop {
if lexer.skip(Token::Word("break")) {
// Branch for the `break if` statement, this statement
@@ -2009,7 +2016,7 @@ impl Parser {
break;
} else {
// Otherwise try to parse a statement
- self.statement(lexer, ctx, &mut continuing)?;
+ self.statement(lexer, ctx, &mut continuing, brace_nesting_level)?;
}
}
// Since the continuing block must be the last part of the loop body,
@@ -2023,7 +2030,7 @@ impl Parser {
break;
}
// Otherwise try to parse a statement
- self.statement(lexer, ctx, &mut body)?;
+ self.statement(lexer, ctx, &mut body, brace_nesting_level)?;
}
ctx.local_table.pop_scope();
@@ -2040,15 +2047,17 @@ impl Parser {
&mut self,
lexer: &mut Lexer<'a>,
ctx: &mut ExpressionContext<'a, '_, '_>,
+ brace_nesting_level: u8,
) -> Result<(ast::Block<'a>, Span), Error<'a>> {
self.push_rule_span(Rule::Block, lexer);
ctx.local_table.push_scope();
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
let mut block = ast::Block::default();
while !lexer.skip(Token::Paren('}')) {
- self.statement(lexer, ctx, &mut block)?;
+ self.statement(lexer, ctx, &mut block, brace_nesting_level)?;
}
ctx.local_table.pop_scope();
@@ -2135,9 +2144,10 @@ impl Parser {
// do not use `self.block` here, since we must not push a new scope
lexer.expect(Token::Paren('{'))?;
+ let brace_nesting_level = 1;
let mut body = ast::Block::default();
while !lexer.skip(Token::Paren('}')) {
- self.statement(lexer, &mut ctx, &mut body)?;
+ self.statement(lexer, &mut ctx, &mut body, brace_nesting_level)?;
}
ctx.local_table.pop_scope();
@@ -2170,6 +2180,7 @@ impl Parser {
let mut early_depth_test = ParsedAttribute::default();
let (mut bind_index, mut bind_group) =
(ParsedAttribute::default(), ParsedAttribute::default());
+ let mut id = ParsedAttribute::default();
let mut dependencies = FastIndexSet::default();
let mut ctx = ExpressionContext {
@@ -2193,6 +2204,11 @@ impl Parser {
bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
+ ("id", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ id.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
("vertex", name_span) => {
stage.set(crate::ShaderStage::Vertex, name_span)?;
}
@@ -2283,6 +2299,30 @@ impl Parser {
Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init }))
}
+ (Token::Word("override"), _) => {
+ let name = lexer.next_ident()?;
+
+ let ty = if lexer.skip(Token::Separator(':')) {
+ Some(self.type_decl(lexer, &mut ctx)?)
+ } else {
+ None
+ };
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ Some(self.general_expression(lexer, &mut ctx)?)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Separator(';'))?;
+
+ Some(ast::GlobalDeclKind::Override(ast::Override {
+ name,
+ id: id.value,
+ ty,
+ init,
+ }))
+ }
(Token::Word("var"), _) => {
let mut var = self.variable_decl(lexer, &mut ctx)?;
var.binding = binding.take();
@@ -2347,4 +2387,30 @@ impl Parser {
Ok(tu)
}
+
+ const fn increase_brace_nesting(
+ brace_nesting_level: u8,
+ brace_span: Span,
+ ) -> Result<u8, Error<'static>> {
+ // From [spec.](https://gpuweb.github.io/gpuweb/wgsl/#limits):
+ //
+ // > § 2.4. Limits
+ // >
+ // > …
+ // >
+ // > Maximum nesting depth of brace-enclosed statements in a function[:] 127
+ //
+ // _However_, we choose 64 instead because (a) it avoids stack overflows in CI and
+ // (b) we expect the limit to be decreased to 63 based on this conversation in
+ // WebGPU CTS upstream:
+ // <https://github.com/gpuweb/cts/pull/3389#discussion_r1543742701>
+ const BRACE_NESTING_MAXIMUM: u8 = 64;
+ if brace_nesting_level + 1 > BRACE_NESTING_MAXIMUM {
+ return Err(Error::ExceededLimitForNestedBraces {
+ span: brace_span,
+ limit: BRACE_NESTING_MAXIMUM,
+ });
+ }
+ Ok(brace_nesting_level + 1)
+ }
}
diff --git a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
index c8331ace09..63bc9f7317 100644
--- a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
+++ b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
@@ -226,7 +226,8 @@ mod tests {
let gctx = crate::proc::GlobalCtx {
types: &types,
constants: &crate::Arena::new(),
- const_expressions: &crate::Arena::new(),
+ overrides: &crate::Arena::new(),
+ global_expressions: &crate::Arena::new(),
};
let array = crate::TypeInner::Array {
base: mytype1,