diff options
Diffstat (limited to 'third_party/rust/naga/src/front/wgsl/parse')
-rw-r--r-- | third_party/rust/naga/src/front/wgsl/parse/ast.rs | 9 | ||||
-rw-r--r-- | third_party/rust/naga/src/front/wgsl/parse/conv.rs | 28 | ||||
-rw-r--r-- | third_party/rust/naga/src/front/wgsl/parse/mod.rs | 106 |
3 files changed, 123 insertions, 20 deletions
diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs index dbaac523cb..ea8013ee7c 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/ast.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs @@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> { Fn(Function<'a>), Var(GlobalVariable<'a>), Const(Const<'a>), + Override(Override<'a>), Struct(Struct<'a>), Type(TypeAlias<'a>), } @@ -200,6 +201,14 @@ pub struct Const<'a> { pub init: Handle<Expression<'a>>, } +#[derive(Debug)] +pub struct Override<'a> { + pub name: Ident<'a>, + pub id: Option<Handle<Expression<'a>>>, + pub ty: Option<Handle<Type<'a>>>, + pub init: Option<Handle<Expression<'a>>>, +} + /// The size of an [`Array`] or [`BindingArray`]. /// /// [`Array`]: Type::Array diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs index 1a4911a3bd..207f0eda41 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/conv.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs @@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, + // subgroup + "num_subgroups" => crate::BuiltIn::NumSubgroups, + "subgroup_id" => crate::BuiltIn::SubgroupId, + "subgroup_size" => crate::BuiltIn::SubgroupSize, + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnknownBuiltin(span)), }) } @@ -260,3 +265,26 @@ pub fn map_conservative_depth( _ => Err(Error::UnknownConservativeDepth(span)), } } + +pub fn map_subgroup_operation( + word: &str, +) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> { + use crate::CollectiveOperation as co; + use crate::SubgroupOperation as sg; + Some(match word { + "subgroupAll" => (sg::All, co::Reduce), + "subgroupAny" => (sg::Any, co::Reduce), + "subgroupAdd" => (sg::Add, co::Reduce), + "subgroupMul" => (sg::Mul, co::Reduce), + "subgroupMin" => (sg::Min, co::Reduce), + "subgroupMax" => (sg::Max, co::Reduce), + "subgroupAnd" => (sg::And, co::Reduce), + "subgroupOr" => (sg::Or, co::Reduce), + "subgroupXor" => (sg::Xor, co::Reduce), + "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan), + "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan), + "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan), + "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs index 51fc2f013b..79ea1ae609 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/mod.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs @@ -1619,22 +1619,21 @@ impl Parser { lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, block: &mut ast::Block<'a>, + brace_nesting_level: u8, ) -> Result<(), Error<'a>> { self.push_rule_span(Rule::Statement, lexer); match lexer.peek() { (Token::Separator(';'), _) => { let _ = lexer.next(); self.pop_rule_span(lexer); - return Ok(()); } (Token::Paren('{'), _) => { - let (inner, span) = self.block(lexer, ctx)?; + let (inner, span) = self.block(lexer, ctx, brace_nesting_level)?; block.stmts.push(ast::Statement { kind: ast::StatementKind::Block(inner), span, }); self.pop_rule_span(lexer); - return Ok(()); } (Token::Word(word), _) => { let kind = match word { @@ -1711,7 +1710,7 @@ impl Parser { let _ = lexer.next(); let condition = self.general_expression(lexer, ctx)?; - let accept = self.block(lexer, ctx)?.0; + let accept = self.block(lexer, ctx, brace_nesting_level)?.0; let mut elsif_stack = Vec::new(); let mut elseif_span_start = lexer.start_byte_offset(); @@ -1722,12 +1721,12 @@ impl Parser { if !lexer.skip(Token::Word("if")) { // ... else { ... } - break self.block(lexer, ctx)?.0; + break self.block(lexer, ctx, brace_nesting_level)?.0; } // ... else if (...) { ... } let other_condition = self.general_expression(lexer, ctx)?; - let other_block = self.block(lexer, ctx)?; + let other_block = self.block(lexer, ctx, brace_nesting_level)?; elsif_stack.push((elseif_span_start, other_condition, other_block)); elseif_span_start = lexer.start_byte_offset(); }; @@ -1759,7 +1758,9 @@ impl Parser { "switch" => { let _ = lexer.next(); let selector = self.general_expression(lexer, ctx)?; - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = + Self::increase_brace_nesting(brace_nesting_level, brace_span)?; let mut cases = Vec::new(); loop { @@ -1784,7 +1785,7 @@ impl Parser { }); }; - let body = self.block(lexer, ctx)?.0; + let body = self.block(lexer, ctx, brace_nesting_level)?.0; cases.push(ast::SwitchCase { value, @@ -1794,7 +1795,7 @@ impl Parser { } (Token::Word("default"), _) => { lexer.skip(Token::Separator(':')); - let body = self.block(lexer, ctx)?.0; + let body = self.block(lexer, ctx, brace_nesting_level)?.0; cases.push(ast::SwitchCase { value: ast::SwitchValue::Default, body, @@ -1810,7 +1811,7 @@ impl Parser { ast::StatementKind::Switch { selector, cases } } - "loop" => self.r#loop(lexer, ctx)?, + "loop" => self.r#loop(lexer, ctx, brace_nesting_level)?, "while" => { let _ = lexer.next(); let mut body = ast::Block::default(); @@ -1834,7 +1835,7 @@ impl Parser { span, }); - let (block, span) = self.block(lexer, ctx)?; + let (block, span) = self.block(lexer, ctx, brace_nesting_level)?; body.stmts.push(ast::Statement { kind: ast::StatementKind::Block(block), span, @@ -1857,7 +1858,9 @@ impl Parser { let (_, span) = { let ctx = &mut *ctx; let block = &mut *block; - lexer.capture_span(|lexer| self.statement(lexer, ctx, block))? + lexer.capture_span(|lexer| { + self.statement(lexer, ctx, block, brace_nesting_level) + })? }; if block.stmts.len() != num_statements { @@ -1902,7 +1905,7 @@ impl Parser { lexer.expect(Token::Paren(')'))?; } - let (block, span) = self.block(lexer, ctx)?; + let (block, span) = self.block(lexer, ctx, brace_nesting_level)?; body.stmts.push(ast::Statement { kind: ast::StatementKind::Block(block), span, @@ -1964,13 +1967,15 @@ impl Parser { &mut self, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, + brace_nesting_level: u8, ) -> Result<ast::StatementKind<'a>, Error<'a>> { let _ = lexer.next(); let mut body = ast::Block::default(); let mut continuing = ast::Block::default(); let mut break_if = None; - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?; ctx.local_table.push_scope(); @@ -1980,7 +1985,9 @@ impl Parser { // the last thing in the loop body // Expect a opening brace to start the continuing block - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = + Self::increase_brace_nesting(brace_nesting_level, brace_span)?; loop { if lexer.skip(Token::Word("break")) { // Branch for the `break if` statement, this statement @@ -2009,7 +2016,7 @@ impl Parser { break; } else { // Otherwise try to parse a statement - self.statement(lexer, ctx, &mut continuing)?; + self.statement(lexer, ctx, &mut continuing, brace_nesting_level)?; } } // Since the continuing block must be the last part of the loop body, @@ -2023,7 +2030,7 @@ impl Parser { break; } // Otherwise try to parse a statement - self.statement(lexer, ctx, &mut body)?; + self.statement(lexer, ctx, &mut body, brace_nesting_level)?; } ctx.local_table.pop_scope(); @@ -2040,15 +2047,17 @@ impl Parser { &mut self, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, + brace_nesting_level: u8, ) -> Result<(ast::Block<'a>, Span), Error<'a>> { self.push_rule_span(Rule::Block, lexer); ctx.local_table.push_scope(); - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?; let mut block = ast::Block::default(); while !lexer.skip(Token::Paren('}')) { - self.statement(lexer, ctx, &mut block)?; + self.statement(lexer, ctx, &mut block, brace_nesting_level)?; } ctx.local_table.pop_scope(); @@ -2135,9 +2144,10 @@ impl Parser { // do not use `self.block` here, since we must not push a new scope lexer.expect(Token::Paren('{'))?; + let brace_nesting_level = 1; let mut body = ast::Block::default(); while !lexer.skip(Token::Paren('}')) { - self.statement(lexer, &mut ctx, &mut body)?; + self.statement(lexer, &mut ctx, &mut body, brace_nesting_level)?; } ctx.local_table.pop_scope(); @@ -2170,6 +2180,7 @@ impl Parser { let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); + let mut id = ParsedAttribute::default(); let mut dependencies = FastIndexSet::default(); let mut ctx = ExpressionContext { @@ -2193,6 +2204,11 @@ impl Parser { bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } + ("id", name_span) => { + lexer.expect(Token::Paren('('))?; + id.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } ("vertex", name_span) => { stage.set(crate::ShaderStage::Vertex, name_span)?; } @@ -2283,6 +2299,30 @@ impl Parser { Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) } + (Token::Word("override"), _) => { + let name = lexer.next_ident()?; + + let ty = if lexer.skip(Token::Separator(':')) { + Some(self.type_decl(lexer, &mut ctx)?) + } else { + None + }; + + let init = if lexer.skip(Token::Operation('=')) { + Some(self.general_expression(lexer, &mut ctx)?) + } else { + None + }; + + lexer.expect(Token::Separator(';'))?; + + Some(ast::GlobalDeclKind::Override(ast::Override { + name, + id: id.value, + ty, + init, + })) + } (Token::Word("var"), _) => { let mut var = self.variable_decl(lexer, &mut ctx)?; var.binding = binding.take(); @@ -2347,4 +2387,30 @@ impl Parser { Ok(tu) } + + const fn increase_brace_nesting( + brace_nesting_level: u8, + brace_span: Span, + ) -> Result<u8, Error<'static>> { + // From [spec.](https://gpuweb.github.io/gpuweb/wgsl/#limits): + // + // > § 2.4. Limits + // > + // > … + // > + // > Maximum nesting depth of brace-enclosed statements in a function[:] 127 + // + // _However_, we choose 64 instead because (a) it avoids stack overflows in CI and + // (b) we expect the limit to be decreased to 63 based on this conversation in + // WebGPU CTS upstream: + // <https://github.com/gpuweb/cts/pull/3389#discussion_r1543742701> + const BRACE_NESTING_MAXIMUM: u8 = 64; + if brace_nesting_level + 1 > BRACE_NESTING_MAXIMUM { + return Err(Error::ExceededLimitForNestedBraces { + span: brace_span, + limit: BRACE_NESTING_MAXIMUM, + }); + } + Ok(brace_nesting_level + 1) + } } |