summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/front/wgsl
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
commit26a029d407be480d791972afb5975cf62c9360a6 (patch)
treef435a8308119effd964b339f76abb83a57c29483 /third_party/rust/naga/src/front/wgsl
parentInitial commit. (diff)
downloadfirefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz
firefox-26a029d407be480d791972afb5975cf62c9360a6.zip
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/front/wgsl')
-rw-r--r--third_party/rust/naga/src/front/wgsl/error.rs775
-rw-r--r--third_party/rust/naga/src/front/wgsl/index.rs193
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/construction.rs616
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/conversion.rs503
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/mod.rs2760
-rw-r--r--third_party/rust/naga/src/front/wgsl/mod.rs49
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/ast.rs491
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/conv.rs254
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/lexer.rs739
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/mod.rs2350
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/number.rs420
-rw-r--r--third_party/rust/naga/src/front/wgsl/tests.rs637
-rw-r--r--third_party/rust/naga/src/front/wgsl/to_wgsl.rs283
13 files changed, 10070 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs
new file mode 100644
index 0000000000..07e68f8dd9
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/error.rs
@@ -0,0 +1,775 @@
+use crate::front::wgsl::parse::lexer::Token;
+use crate::front::wgsl::Scalar;
+use crate::proc::{Alignment, ConstantEvaluatorError, ResolveError};
+use crate::{SourceLocation, Span};
+use codespan_reporting::diagnostic::{Diagnostic, Label};
+use codespan_reporting::files::SimpleFile;
+use codespan_reporting::term;
+use std::borrow::Cow;
+use std::ops::Range;
+use termcolor::{ColorChoice, NoColor, StandardStream};
+use thiserror::Error;
+
+#[derive(Clone, Debug)]
+pub struct ParseError {
+ message: String,
+ labels: Vec<(Span, Cow<'static, str>)>,
+ notes: Vec<String>,
+}
+
+impl ParseError {
+ pub fn labels(&self) -> impl ExactSizeIterator<Item = (Span, &str)> + '_ {
+ self.labels
+ .iter()
+ .map(|&(span, ref msg)| (span, msg.as_ref()))
+ }
+
+ pub fn message(&self) -> &str {
+ &self.message
+ }
+
+ fn diagnostic(&self) -> Diagnostic<()> {
+ let diagnostic = Diagnostic::error()
+ .with_message(self.message.to_string())
+ .with_labels(
+ self.labels
+ .iter()
+ .filter_map(|label| label.0.to_range().map(|range| (label, range)))
+ .map(|(label, range)| {
+ Label::primary((), range).with_message(label.1.to_string())
+ })
+ .collect(),
+ )
+ .with_notes(
+ self.notes
+ .iter()
+ .map(|note| format!("note: {note}"))
+ .collect(),
+ );
+ diagnostic
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr(&self, source: &str) {
+ self.emit_to_stderr_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr_with_path<P>(&self, source: &str, path: P)
+ where
+ P: AsRef<std::path::Path>,
+ {
+ let path = path.as_ref().display().to_string();
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let writer = StandardStream::stderr(ColorChoice::Auto);
+ term::emit(&mut writer.lock(), &config, &files, &self.diagnostic())
+ .expect("cannot write error");
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string(&self, source: &str) -> String {
+ self.emit_to_string_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string_with_path<P>(&self, source: &str, path: P) -> String
+ where
+ P: AsRef<std::path::Path>,
+ {
+ let path = path.as_ref().display().to_string();
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let mut writer = NoColor::new(Vec::new());
+ term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error");
+ String::from_utf8(writer.into_inner()).unwrap()
+ }
+
+ /// Returns a [`SourceLocation`] for the first label in the error message.
+ pub fn location(&self, source: &str) -> Option<SourceLocation> {
+ self.labels.get(0).map(|label| label.0.location(source))
+ }
+}
+
+impl std::fmt::Display for ParseError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.message)
+ }
+}
+
+impl std::error::Error for ParseError {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ None
+ }
+}
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum ExpectedToken<'a> {
+ Token(Token<'a>),
+ Identifier,
+ /// Expected: constant, parenthesized expression, identifier
+ PrimaryExpression,
+ /// Expected: assignment, increment/decrement expression
+ Assignment,
+ /// Expected: 'case', 'default', '}'
+ SwitchItem,
+ /// Expected: ',', ')'
+ WorkgroupSizeSeparator,
+ /// Expected: 'struct', 'let', 'var', 'type', ';', 'fn', eof
+ GlobalItem,
+ /// Expected a type.
+ Type,
+ /// Access of `var`, `let`, `const`.
+ Variable,
+ /// Access of a function
+ Function,
+}
+
+#[derive(Clone, Copy, Debug, Error, PartialEq)]
+pub enum NumberError {
+ #[error("invalid numeric literal format")]
+ Invalid,
+ #[error("numeric literal not representable by target type")]
+ NotRepresentable,
+ #[error("unimplemented f16 type")]
+ UnimplementedF16,
+}
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum InvalidAssignmentType {
+ Other,
+ Swizzle,
+ ImmutableBinding(Span),
+}
+
+#[derive(Clone, Debug)]
+pub enum Error<'a> {
+ Unexpected(Span, ExpectedToken<'a>),
+ UnexpectedComponents(Span),
+ UnexpectedOperationInConstContext(Span),
+ BadNumber(Span, NumberError),
+ BadMatrixScalarKind(Span, Scalar),
+ BadAccessor(Span),
+ BadTexture(Span),
+ BadTypeCast {
+ span: Span,
+ from_type: String,
+ to_type: String,
+ },
+ BadTextureSampleType {
+ span: Span,
+ scalar: Scalar,
+ },
+ BadIncrDecrReferenceType(Span),
+ InvalidResolve(ResolveError),
+ InvalidForInitializer(Span),
+ /// A break if appeared outside of a continuing block
+ InvalidBreakIf(Span),
+ InvalidGatherComponent(Span),
+ InvalidConstructorComponentType(Span, i32),
+ InvalidIdentifierUnderscore(Span),
+ ReservedIdentifierPrefix(Span),
+ UnknownAddressSpace(Span),
+ RepeatedAttribute(Span),
+ UnknownAttribute(Span),
+ UnknownBuiltin(Span),
+ UnknownAccess(Span),
+ UnknownIdent(Span, &'a str),
+ UnknownScalarType(Span),
+ UnknownType(Span),
+ UnknownStorageFormat(Span),
+ UnknownConservativeDepth(Span),
+ SizeAttributeTooLow(Span, u32),
+ AlignAttributeTooLow(Span, Alignment),
+ NonPowerOfTwoAlignAttribute(Span),
+ InconsistentBinding(Span),
+ TypeNotConstructible(Span),
+ TypeNotInferable(Span),
+ InitializationTypeMismatch {
+ name: Span,
+ expected: String,
+ got: String,
+ },
+ MissingType(Span),
+ MissingAttribute(&'static str, Span),
+ InvalidAtomicPointer(Span),
+ InvalidAtomicOperandType(Span),
+ InvalidRayQueryPointer(Span),
+ Pointer(&'static str, Span),
+ NotPointer(Span),
+ NotReference(&'static str, Span),
+ InvalidAssignment {
+ span: Span,
+ ty: InvalidAssignmentType,
+ },
+ ReservedKeyword(Span),
+ /// Redefinition of an identifier (used for both module-scope and local redefinitions).
+ Redefinition {
+ /// Span of the identifier in the previous definition.
+ previous: Span,
+
+ /// Span of the identifier in the new definition.
+ current: Span,
+ },
+ /// A declaration refers to itself directly.
+ RecursiveDeclaration {
+ /// The location of the name of the declaration.
+ ident: Span,
+
+ /// The point at which it is used.
+ usage: Span,
+ },
+ /// A declaration refers to itself indirectly, through one or more other
+ /// definitions.
+ CyclicDeclaration {
+ /// The location of the name of some declaration in the cycle.
+ ident: Span,
+
+ /// The edges of the cycle of references.
+ ///
+ /// Each `(decl, reference)` pair indicates that the declaration whose
+ /// name is `decl` has an identifier at `reference` whose definition is
+ /// the next declaration in the cycle. The last pair's `reference` is
+ /// the same identifier as `ident`, above.
+ path: Vec<(Span, Span)>,
+ },
+ InvalidSwitchValue {
+ uint: bool,
+ span: Span,
+ },
+ CalledEntryPoint(Span),
+ WrongArgumentCount {
+ span: Span,
+ expected: Range<u32>,
+ found: u32,
+ },
+ FunctionReturnsVoid(Span),
+ InvalidWorkGroupUniformLoad(Span),
+ Internal(&'static str),
+ ExpectedConstExprConcreteIntegerScalar(Span),
+ ExpectedNonNegative(Span),
+ ExpectedPositiveArrayLength(Span),
+ MissingWorkgroupSize(Span),
+ ConstantEvaluatorError(ConstantEvaluatorError, Span),
+ AutoConversion {
+ dest_span: Span,
+ dest_type: String,
+ source_span: Span,
+ source_type: String,
+ },
+ AutoConversionLeafScalar {
+ dest_span: Span,
+ dest_scalar: String,
+ source_span: Span,
+ source_type: String,
+ },
+ ConcretizationFailed {
+ expr_span: Span,
+ expr_type: String,
+ scalar: String,
+ inner: ConstantEvaluatorError,
+ },
+}
+
+impl<'a> Error<'a> {
+ pub(crate) fn as_parse_error(&self, source: &'a str) -> ParseError {
+ match *self {
+ Error::Unexpected(unexpected_span, expected) => {
+ let expected_str = match expected {
+ ExpectedToken::Token(token) => {
+ match token {
+ Token::Separator(c) => format!("'{c}'"),
+ Token::Paren(c) => format!("'{c}'"),
+ Token::Attribute => "@".to_string(),
+ Token::Number(_) => "number".to_string(),
+ Token::Word(s) => s.to_string(),
+ Token::Operation(c) => format!("operation ('{c}')"),
+ Token::LogicalOperation(c) => format!("logical operation ('{c}')"),
+ Token::ShiftOperation(c) => format!("bitshift ('{c}{c}')"),
+ Token::AssignmentOperation(c) if c=='<' || c=='>' => format!("bitshift ('{c}{c}=')"),
+ Token::AssignmentOperation(c) => format!("operation ('{c}=')"),
+ Token::IncrementOperation => "increment operation".to_string(),
+ Token::DecrementOperation => "decrement operation".to_string(),
+ Token::Arrow => "->".to_string(),
+ Token::Unknown(c) => format!("unknown ('{c}')"),
+ Token::Trivia => "trivia".to_string(),
+ Token::End => "end".to_string(),
+ }
+ }
+ ExpectedToken::Identifier => "identifier".to_string(),
+ ExpectedToken::PrimaryExpression => "expression".to_string(),
+ ExpectedToken::Assignment => "assignment or increment/decrement".to_string(),
+ ExpectedToken::SwitchItem => "switch item ('case' or 'default') or a closing curly bracket to signify the end of the switch statement ('}')".to_string(),
+ ExpectedToken::WorkgroupSizeSeparator => "workgroup size separator (',') or a closing parenthesis".to_string(),
+ ExpectedToken::GlobalItem => "global item ('struct', 'const', 'var', 'alias', ';', 'fn') or the end of the file".to_string(),
+ ExpectedToken::Type => "type".to_string(),
+ ExpectedToken::Variable => "variable access".to_string(),
+ ExpectedToken::Function => "function name".to_string(),
+ };
+ ParseError {
+ message: format!(
+ "expected {}, found '{}'",
+ expected_str, &source[unexpected_span],
+ ),
+ labels: vec![(unexpected_span, format!("expected {expected_str}").into())],
+ notes: vec![],
+ }
+ }
+ Error::UnexpectedComponents(bad_span) => ParseError {
+ message: "unexpected components".to_string(),
+ labels: vec![(bad_span, "unexpected components".into())],
+ notes: vec![],
+ },
+ Error::UnexpectedOperationInConstContext(span) => ParseError {
+ message: "this operation is not supported in a const context".to_string(),
+ labels: vec![(span, "operation not supported here".into())],
+ notes: vec![],
+ },
+ Error::BadNumber(bad_span, ref err) => ParseError {
+ message: format!("{}: `{}`", err, &source[bad_span],),
+ labels: vec![(bad_span, err.to_string().into())],
+ notes: vec![],
+ },
+ Error::BadMatrixScalarKind(span, scalar) => ParseError {
+ message: format!(
+ "matrix scalar type must be floating-point, but found `{}`",
+ scalar.to_wgsl()
+ ),
+ labels: vec![(span, "must be floating-point (e.g. `f32`)".into())],
+ notes: vec![],
+ },
+ Error::BadAccessor(accessor_span) => ParseError {
+ message: format!("invalid field accessor `{}`", &source[accessor_span],),
+ labels: vec![(accessor_span, "invalid accessor".into())],
+ notes: vec![],
+ },
+ Error::UnknownIdent(ident_span, ident) => ParseError {
+ message: format!("no definition in scope for identifier: '{ident}'"),
+ labels: vec![(ident_span, "unknown identifier".into())],
+ notes: vec![],
+ },
+ Error::UnknownScalarType(bad_span) => ParseError {
+ message: format!("unknown scalar type: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown scalar type".into())],
+ notes: vec!["Valid scalar types are f32, f64, i32, u32, bool".into()],
+ },
+ Error::BadTextureSampleType { span, scalar } => ParseError {
+ message: format!(
+ "texture sample type must be one of f32, i32 or u32, but found {}",
+ scalar.to_wgsl()
+ ),
+ labels: vec![(span, "must be one of f32, i32 or u32".into())],
+ notes: vec![],
+ },
+ Error::BadIncrDecrReferenceType(span) => ParseError {
+ message:
+ "increment/decrement operation requires reference type to be one of i32 or u32"
+ .to_string(),
+ labels: vec![(span, "must be a reference type of i32 or u32".into())],
+ notes: vec![],
+ },
+ Error::BadTexture(bad_span) => ParseError {
+ message: format!(
+ "expected an image, but found '{}' which is not an image",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "not an image".into())],
+ notes: vec![],
+ },
+ Error::BadTypeCast {
+ span,
+ ref from_type,
+ ref to_type,
+ } => {
+ let msg = format!("cannot cast a {from_type} to a {to_type}");
+ ParseError {
+ message: msg.clone(),
+ labels: vec![(span, msg.into())],
+ notes: vec![],
+ }
+ }
+ Error::InvalidResolve(ref resolve_error) => ParseError {
+ message: resolve_error.to_string(),
+ labels: vec![],
+ notes: vec![],
+ },
+ Error::InvalidForInitializer(bad_span) => ParseError {
+ message: format!(
+ "for(;;) initializer is not an assignment or a function call: '{}'",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "not an assignment or function call".into())],
+ notes: vec![],
+ },
+ Error::InvalidBreakIf(bad_span) => ParseError {
+ message: "A break if is only allowed in a continuing block".to_string(),
+ labels: vec![(bad_span, "not in a continuing block".into())],
+ notes: vec![],
+ },
+ Error::InvalidGatherComponent(bad_span) => ParseError {
+ message: format!(
+ "textureGather component '{}' doesn't exist, must be 0, 1, 2, or 3",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "invalid component".into())],
+ notes: vec![],
+ },
+ Error::InvalidConstructorComponentType(bad_span, component) => ParseError {
+ message: format!("invalid type for constructor component at index [{component}]"),
+ labels: vec![(bad_span, "invalid component type".into())],
+ notes: vec![],
+ },
+ Error::InvalidIdentifierUnderscore(bad_span) => ParseError {
+ message: "Identifier can't be '_'".to_string(),
+ labels: vec![(bad_span, "invalid identifier".into())],
+ notes: vec![
+ "Use phony assignment instead ('_ =' notice the absence of 'let' or 'var')"
+ .to_string(),
+ ],
+ },
+ Error::ReservedIdentifierPrefix(bad_span) => ParseError {
+ message: format!(
+ "Identifier starts with a reserved prefix: '{}'",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "invalid identifier".into())],
+ notes: vec![],
+ },
+ Error::UnknownAddressSpace(bad_span) => ParseError {
+ message: format!("unknown address space: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown address space".into())],
+ notes: vec![],
+ },
+ Error::RepeatedAttribute(bad_span) => ParseError {
+ message: format!("repeated attribute: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "repeated attribute".into())],
+ notes: vec![],
+ },
+ Error::UnknownAttribute(bad_span) => ParseError {
+ message: format!("unknown attribute: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown attribute".into())],
+ notes: vec![],
+ },
+ Error::UnknownBuiltin(bad_span) => ParseError {
+ message: format!("unknown builtin: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown builtin".into())],
+ notes: vec![],
+ },
+ Error::UnknownAccess(bad_span) => ParseError {
+ message: format!("unknown access: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown access".into())],
+ notes: vec![],
+ },
+ Error::UnknownStorageFormat(bad_span) => ParseError {
+ message: format!("unknown storage format: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown storage format".into())],
+ notes: vec![],
+ },
+ Error::UnknownConservativeDepth(bad_span) => ParseError {
+ message: format!("unknown conservative depth: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown conservative depth".into())],
+ notes: vec![],
+ },
+ Error::UnknownType(bad_span) => ParseError {
+ message: format!("unknown type: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown type".into())],
+ notes: vec![],
+ },
+ Error::SizeAttributeTooLow(bad_span, min_size) => ParseError {
+ message: format!("struct member size must be at least {min_size}"),
+ labels: vec![(bad_span, format!("must be at least {min_size}").into())],
+ notes: vec![],
+ },
+ Error::AlignAttributeTooLow(bad_span, min_align) => ParseError {
+ message: format!("struct member alignment must be at least {min_align}"),
+ labels: vec![(bad_span, format!("must be at least {min_align}").into())],
+ notes: vec![],
+ },
+ Error::NonPowerOfTwoAlignAttribute(bad_span) => ParseError {
+ message: "struct member alignment must be a power of 2".to_string(),
+ labels: vec![(bad_span, "must be a power of 2".into())],
+ notes: vec![],
+ },
+ Error::InconsistentBinding(span) => ParseError {
+ message: "input/output binding is not consistent".to_string(),
+ labels: vec![(span, "input/output binding is not consistent".into())],
+ notes: vec![],
+ },
+ Error::TypeNotConstructible(span) => ParseError {
+ message: format!("type `{}` is not constructible", &source[span]),
+ labels: vec![(span, "type is not constructible".into())],
+ notes: vec![],
+ },
+ Error::TypeNotInferable(span) => ParseError {
+ message: "type can't be inferred".to_string(),
+ labels: vec![(span, "type can't be inferred".into())],
+ notes: vec![],
+ },
+ Error::InitializationTypeMismatch { name, ref expected, ref got } => {
+ ParseError {
+ message: format!(
+ "the type of `{}` is expected to be `{}`, but got `{}`",
+ &source[name], expected, got,
+ ),
+ labels: vec![(
+ name,
+ format!("definition of `{}`", &source[name]).into(),
+ )],
+ notes: vec![],
+ }
+ }
+ Error::MissingType(name_span) => ParseError {
+ message: format!("variable `{}` needs a type", &source[name_span]),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::MissingAttribute(name, name_span) => ParseError {
+ message: format!(
+ "variable `{}` needs a '{}' attribute",
+ &source[name_span], name
+ ),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::InvalidAtomicPointer(span) => ParseError {
+ message: "atomic operation is done on a pointer to a non-atomic".to_string(),
+ labels: vec![(span, "atomic pointer is invalid".into())],
+ notes: vec![],
+ },
+ Error::InvalidAtomicOperandType(span) => ParseError {
+ message: "atomic operand type is inconsistent with the operation".to_string(),
+ labels: vec![(span, "atomic operand type is invalid".into())],
+ notes: vec![],
+ },
+ Error::InvalidRayQueryPointer(span) => ParseError {
+ message: "ray query operation is done on a pointer to a non-ray-query".to_string(),
+ labels: vec![(span, "ray query pointer is invalid".into())],
+ notes: vec![],
+ },
+ Error::NotPointer(span) => ParseError {
+ message: "the operand of the `*` operator must be a pointer".to_string(),
+ labels: vec![(span, "expression is not a pointer".into())],
+ notes: vec![],
+ },
+ Error::NotReference(what, span) => ParseError {
+ message: format!("{what} must be a reference"),
+ labels: vec![(span, "expression is not a reference".into())],
+ notes: vec![],
+ },
+ Error::InvalidAssignment { span, ty } => {
+ let (extra_label, notes) = match ty {
+ InvalidAssignmentType::Swizzle => (
+ None,
+ vec![
+ "WGSL does not support assignments to swizzles".into(),
+ "consider assigning each component individually".into(),
+ ],
+ ),
+ InvalidAssignmentType::ImmutableBinding(binding_span) => (
+ Some((binding_span, "this is an immutable binding".into())),
+ vec![format!(
+ "consider declaring '{}' with `var` instead of `let`",
+ &source[binding_span]
+ )],
+ ),
+ InvalidAssignmentType::Other => (None, vec![]),
+ };
+
+ ParseError {
+ message: "invalid left-hand side of assignment".into(),
+ labels: std::iter::once((span, "cannot assign to this expression".into()))
+ .chain(extra_label)
+ .collect(),
+ notes,
+ }
+ }
+ Error::Pointer(what, span) => ParseError {
+ message: format!("{what} must not be a pointer"),
+ labels: vec![(span, "expression is a pointer".into())],
+ notes: vec![],
+ },
+ Error::ReservedKeyword(name_span) => ParseError {
+ message: format!("name `{}` is a reserved keyword", &source[name_span]),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::Redefinition { previous, current } => ParseError {
+ message: format!("redefinition of `{}`", &source[current]),
+ labels: vec![
+ (
+ current,
+ format!("redefinition of `{}`", &source[current]).into(),
+ ),
+ (
+ previous,
+ format!("previous definition of `{}`", &source[previous]).into(),
+ ),
+ ],
+ notes: vec![],
+ },
+ Error::RecursiveDeclaration { ident, usage } => ParseError {
+ message: format!("declaration of `{}` is recursive", &source[ident]),
+ labels: vec![(ident, "".into()), (usage, "uses itself here".into())],
+ notes: vec![],
+ },
+ Error::CyclicDeclaration { ident, ref path } => ParseError {
+ message: format!("declaration of `{}` is cyclic", &source[ident]),
+ labels: path
+ .iter()
+ .enumerate()
+ .flat_map(|(i, &(ident, usage))| {
+ [
+ (ident, "".into()),
+ (
+ usage,
+ if i == path.len() - 1 {
+ "ending the cycle".into()
+ } else {
+ format!("uses `{}`", &source[ident]).into()
+ },
+ ),
+ ]
+ })
+ .collect(),
+ notes: vec![],
+ },
+ Error::InvalidSwitchValue { uint, span } => ParseError {
+ message: "invalid switch value".to_string(),
+ labels: vec![(
+ span,
+ if uint {
+ "expected unsigned integer"
+ } else {
+ "expected signed integer"
+ }
+ .into(),
+ )],
+ notes: vec![if uint {
+ format!("suffix the integer with a `u`: '{}u'", &source[span])
+ } else {
+ let span = span.to_range().unwrap();
+ format!(
+ "remove the `u` suffix: '{}'",
+ &source[span.start..span.end - 1]
+ )
+ }],
+ },
+ Error::CalledEntryPoint(span) => ParseError {
+ message: "entry point cannot be called".to_string(),
+ labels: vec![(span, "entry point cannot be called".into())],
+ notes: vec![],
+ },
+ Error::WrongArgumentCount {
+ span,
+ ref expected,
+ found,
+ } => ParseError {
+ message: format!(
+ "wrong number of arguments: expected {}, found {}",
+ if expected.len() < 2 {
+ format!("{}", expected.start)
+ } else {
+ format!("{}..{}", expected.start, expected.end)
+ },
+ found
+ ),
+ labels: vec![(span, "wrong number of arguments".into())],
+ notes: vec![],
+ },
+ Error::FunctionReturnsVoid(span) => ParseError {
+ message: "function does not return any value".to_string(),
+ labels: vec![(span, "".into())],
+ notes: vec![
+ "perhaps you meant to call the function in a separate statement?".into(),
+ ],
+ },
+ Error::InvalidWorkGroupUniformLoad(span) => ParseError {
+ message: "incorrect type passed to workgroupUniformLoad".into(),
+ labels: vec![(span, "".into())],
+ notes: vec!["passed type must be a workgroup pointer".into()],
+ },
+ Error::Internal(message) => ParseError {
+ message: "internal WGSL front end error".to_string(),
+ labels: vec![],
+ notes: vec![message.into()],
+ },
+ Error::ExpectedConstExprConcreteIntegerScalar(span) => ParseError {
+ message: "must be a const-expression that resolves to a concrete integer scalar (u32 or i32)".to_string(),
+ labels: vec![(span, "must resolve to u32 or i32".into())],
+ notes: vec![],
+ },
+ Error::ExpectedNonNegative(span) => ParseError {
+ message: "must be non-negative (>= 0)".to_string(),
+ labels: vec![(span, "must be non-negative".into())],
+ notes: vec![],
+ },
+ Error::ExpectedPositiveArrayLength(span) => ParseError {
+ message: "array element count must be positive (> 0)".to_string(),
+ labels: vec![(span, "must be positive".into())],
+ notes: vec![],
+ },
+ Error::ConstantEvaluatorError(ref e, span) => ParseError {
+ message: e.to_string(),
+ labels: vec![(span, "see msg".into())],
+ notes: vec![],
+ },
+ Error::MissingWorkgroupSize(span) => ParseError {
+ message: "workgroup size is missing on compute shader entry point".to_string(),
+ labels: vec![(
+ span,
+ "must be paired with a @workgroup_size attribute".into(),
+ )],
+ notes: vec![],
+ },
+ Error::AutoConversion { dest_span, ref dest_type, source_span, ref source_type } => ParseError {
+ message: format!("automatic conversions cannot convert `{source_type}` to `{dest_type}`"),
+ labels: vec![
+ (
+ dest_span,
+ format!("a value of type {dest_type} is required here").into(),
+ ),
+ (
+ source_span,
+ format!("this expression has type {source_type}").into(),
+ )
+ ],
+ notes: vec![],
+ },
+ Error::AutoConversionLeafScalar { dest_span, ref dest_scalar, source_span, ref source_type } => ParseError {
+ message: format!("automatic conversions cannot convert elements of `{source_type}` to `{dest_scalar}`"),
+ labels: vec![
+ (
+ dest_span,
+ format!("a value with elements of type {dest_scalar} is required here").into(),
+ ),
+ (
+ source_span,
+ format!("this expression has type {source_type}").into(),
+ )
+ ],
+ notes: vec![],
+ },
+ Error::ConcretizationFailed { expr_span, ref expr_type, ref scalar, ref inner } => ParseError {
+ message: format!("failed to convert expression to a concrete type: {}", inner),
+ labels: vec![
+ (
+ expr_span,
+ format!("this expression has type {}", expr_type).into(),
+ )
+ ],
+ notes: vec![
+ format!("the expression should have been converted to have {} scalar type", scalar),
+ ]
+ },
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs
new file mode 100644
index 0000000000..a5524fe8f1
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/index.rs
@@ -0,0 +1,193 @@
+use super::Error;
+use crate::front::wgsl::parse::ast;
+use crate::{FastHashMap, Handle, Span};
+
+/// A `GlobalDecl` list in which each definition occurs before all its uses.
+pub struct Index<'a> {
+ dependency_order: Vec<Handle<ast::GlobalDecl<'a>>>,
+}
+
+impl<'a> Index<'a> {
+ /// Generate an `Index` for the given translation unit.
+ ///
+ /// Perform a topological sort on `tu`'s global declarations, placing
+ /// referents before the definitions that refer to them.
+ ///
+ /// Return an error if the graph of references between declarations contains
+ /// any cycles.
+ pub fn generate(tu: &ast::TranslationUnit<'a>) -> Result<Self, Error<'a>> {
+ // Produce a map from global definitions' names to their `Handle<GlobalDecl>`s.
+ // While doing so, reject conflicting definitions.
+ let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default());
+ for (handle, decl) in tu.decls.iter() {
+ let ident = decl_ident(decl);
+ let name = ident.name;
+ if let Some(old) = globals.insert(name, handle) {
+ return Err(Error::Redefinition {
+ previous: decl_ident(&tu.decls[old]).span,
+ current: ident.span,
+ });
+ }
+ }
+
+ let len = tu.decls.len();
+ let solver = DependencySolver {
+ globals: &globals,
+ module: tu,
+ visited: vec![false; len],
+ temp_visited: vec![false; len],
+ path: Vec::new(),
+ out: Vec::with_capacity(len),
+ };
+ let dependency_order = solver.solve()?;
+
+ Ok(Self { dependency_order })
+ }
+
+ /// Iterate over `GlobalDecl`s, visiting each definition before all its uses.
+ ///
+ /// Produce handles for all of the `GlobalDecl`s of the `TranslationUnit`
+ /// passed to `Index::generate`, ordered so that a given declaration is
+ /// produced before any other declaration that uses it.
+ pub fn visit_ordered(&self) -> impl Iterator<Item = Handle<ast::GlobalDecl<'a>>> + '_ {
+ self.dependency_order.iter().copied()
+ }
+}
+
+/// An edge from a reference to its referent in the current depth-first
+/// traversal.
+///
+/// This is like `ast::Dependency`, except that we've determined which
+/// `GlobalDecl` it refers to.
+struct ResolvedDependency<'a> {
+ /// The referent of some identifier used in the current declaration.
+ decl: Handle<ast::GlobalDecl<'a>>,
+
+ /// Where that use occurs within the current declaration.
+ usage: Span,
+}
+
+/// Local state for ordering a `TranslationUnit`'s module-scope declarations.
+///
+/// Values of this type are used temporarily by `Index::generate`
+/// to perform a depth-first sort on the declarations.
+/// Technically, what we want is a topological sort, but a depth-first sort
+/// has one key benefit - it's much more efficient in storing
+/// the path of each node for error generation.
+struct DependencySolver<'source, 'temp> {
+ /// A map from module-scope definitions' names to their handles.
+ globals: &'temp FastHashMap<&'source str, Handle<ast::GlobalDecl<'source>>>,
+
+ /// The translation unit whose declarations we're ordering.
+ module: &'temp ast::TranslationUnit<'source>,
+
+ /// For each handle, whether we have pushed it onto `out` yet.
+ visited: Vec<bool>,
+
+ /// For each handle, whether it is an predecessor in the current depth-first
+ /// traversal. This is used to detect cycles in the reference graph.
+ temp_visited: Vec<bool>,
+
+ /// The current path in our depth-first traversal. Used for generating
+ /// error messages for non-trivial reference cycles.
+ path: Vec<ResolvedDependency<'source>>,
+
+ /// The list of declaration handles, with declarations before uses.
+ out: Vec<Handle<ast::GlobalDecl<'source>>>,
+}
+
+impl<'a> DependencySolver<'a, '_> {
+ /// Produce the sorted list of declaration handles, and check for cycles.
+ fn solve(mut self) -> Result<Vec<Handle<ast::GlobalDecl<'a>>>, Error<'a>> {
+ for (id, _) in self.module.decls.iter() {
+ if self.visited[id.index()] {
+ continue;
+ }
+
+ self.dfs(id)?;
+ }
+
+ Ok(self.out)
+ }
+
+ /// Ensure that all declarations used by `id` have been added to the
+ /// ordering, and then append `id` itself.
+ fn dfs(&mut self, id: Handle<ast::GlobalDecl<'a>>) -> Result<(), Error<'a>> {
+ let decl = &self.module.decls[id];
+ let id_usize = id.index();
+
+ self.temp_visited[id_usize] = true;
+ for dep in decl.dependencies.iter() {
+ if let Some(&dep_id) = self.globals.get(dep.ident) {
+ self.path.push(ResolvedDependency {
+ decl: dep_id,
+ usage: dep.usage,
+ });
+ let dep_id_usize = dep_id.index();
+
+ if self.temp_visited[dep_id_usize] {
+ // Found a cycle.
+ return if dep_id == id {
+ // A declaration refers to itself directly.
+ Err(Error::RecursiveDeclaration {
+ ident: decl_ident(decl).span,
+ usage: dep.usage,
+ })
+ } else {
+ // A declaration refers to itself indirectly, through
+ // one or more other definitions. Report the entire path
+ // of references.
+ let start_at = self
+ .path
+ .iter()
+ .rev()
+ .enumerate()
+ .find_map(|(i, dep)| (dep.decl == dep_id).then_some(i))
+ .unwrap_or(0);
+
+ Err(Error::CyclicDeclaration {
+ ident: decl_ident(&self.module.decls[dep_id]).span,
+ path: self.path[start_at..]
+ .iter()
+ .map(|curr_dep| {
+ let curr_id = curr_dep.decl;
+ let curr_decl = &self.module.decls[curr_id];
+
+ (decl_ident(curr_decl).span, curr_dep.usage)
+ })
+ .collect(),
+ })
+ };
+ } else if !self.visited[dep_id_usize] {
+ self.dfs(dep_id)?;
+ }
+
+ // Remove this edge from the current path.
+ self.path.pop();
+ }
+
+ // Ignore unresolved identifiers; they may be predeclared objects.
+ }
+
+ // Remove this node from the current path.
+ self.temp_visited[id_usize] = false;
+
+ // Now everything this declaration uses has been visited, and is already
+ // present in `out`. That means we we can append this one to the
+ // ordering, and mark it as visited.
+ self.out.push(id);
+ self.visited[id_usize] = true;
+
+ Ok(())
+ }
+}
+
+const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
+ match decl.kind {
+ ast::GlobalDeclKind::Fn(ref f) => f.name,
+ ast::GlobalDeclKind::Var(ref v) => v.name,
+ ast::GlobalDeclKind::Const(ref c) => c.name,
+ ast::GlobalDeclKind::Struct(ref s) => s.name,
+ ast::GlobalDeclKind::Type(ref t) => t.name,
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/construction.rs b/third_party/rust/naga/src/front/wgsl/lower/construction.rs
new file mode 100644
index 0000000000..de0d11d227
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/construction.rs
@@ -0,0 +1,616 @@
+use std::num::NonZeroU32;
+
+use crate::front::wgsl::parse::ast;
+use crate::{Handle, Span};
+
+use crate::front::wgsl::error::Error;
+use crate::front::wgsl::lower::{ExpressionContext, Lowerer};
+
+/// A cooked form of `ast::ConstructorType` that uses Naga types whenever
+/// possible.
+enum Constructor<T> {
+ /// A vector construction whose component type is inferred from the
+ /// argument: `vec3(1.0)`.
+ PartialVector { size: crate::VectorSize },
+
+ /// A matrix construction whose component type is inferred from the
+ /// argument: `mat2x2(1,2,3,4)`.
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+
+ /// An array whose component type and size are inferred from the arguments:
+ /// `array(3,4,5)`.
+ PartialArray,
+
+ /// A known Naga type.
+ ///
+ /// When we match on this type, we need to see the `TypeInner` here, but at
+ /// the point that we build this value we'll still need mutable access to
+ /// the module later. To avoid borrowing from the module, the type parameter
+ /// `T` is `Handle<Type>` initially. Then we use `borrow_inner` to produce a
+ /// version holding a tuple `(Handle<Type>, &TypeInner)`.
+ Type(T),
+}
+
+impl Constructor<Handle<crate::Type>> {
+ /// Return an equivalent `Constructor` value that includes borrowed
+ /// `TypeInner` values alongside any type handles.
+ ///
+ /// The returned form is more convenient to match on, since the patterns
+ /// can actually see what the handle refers to.
+ fn borrow_inner(
+ self,
+ module: &crate::Module,
+ ) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
+ match self {
+ Constructor::PartialVector { size } => Constructor::PartialVector { size },
+ Constructor::PartialMatrix { columns, rows } => {
+ Constructor::PartialMatrix { columns, rows }
+ }
+ Constructor::PartialArray => Constructor::PartialArray,
+ Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)),
+ }
+ }
+}
+
+impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
+ fn to_error_string(&self, ctx: &ExpressionContext) -> String {
+ match *self {
+ Self::PartialVector { size } => {
+ format!("vec{}<?>", size as u32,)
+ }
+ Self::PartialMatrix { columns, rows } => {
+ format!("mat{}x{}<?>", columns as u32, rows as u32,)
+ }
+ Self::PartialArray => "array<?, ?>".to_string(),
+ Self::Type((handle, _inner)) => handle.to_wgsl(&ctx.module.to_ctx()),
+ }
+ }
+}
+
+enum Components<'a> {
+ None,
+ One {
+ component: Handle<crate::Expression>,
+ span: Span,
+ ty_inner: &'a crate::TypeInner,
+ },
+ Many {
+ components: Vec<Handle<crate::Expression>>,
+ spans: Vec<Span>,
+ },
+}
+
+impl Components<'_> {
+ fn into_components_vec(self) -> Vec<Handle<crate::Expression>> {
+ match self {
+ Self::None => vec![],
+ Self::One { component, .. } => vec![component],
+ Self::Many { components, .. } => components,
+ }
+ }
+}
+
+impl<'source, 'temp> Lowerer<'source, 'temp> {
+ /// Generate Naga IR for a type constructor expression.
+ ///
+ /// The `constructor` value represents the head of the constructor
+ /// expression, which is at least a hint of which type is being built; if
+ /// it's one of the `Partial` variants, we need to consider the argument
+ /// types as well.
+ ///
+ /// This is used for [`Construct`] expressions, but also for [`Call`]
+ /// expressions, once we've determined that the "callable" (in WGSL spec
+ /// terms) is actually a type.
+ ///
+ /// [`Construct`]: ast::Expression::Construct
+ /// [`Call`]: ast::Expression::Call
+ pub fn construct(
+ &mut self,
+ span: Span,
+ constructor: &ast::ConstructorType<'source>,
+ ty_span: Span,
+ components: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ use crate::proc::TypeResolution as Tr;
+
+ let constructor_h = self.constructor(constructor, ctx)?;
+
+ let components = match *components {
+ [] => Components::None,
+ [component] => {
+ let span = ctx.ast_expressions.get_span(component);
+ let component = self.expression_for_abstract(component, ctx)?;
+ let ty_inner = super::resolve_inner!(ctx, component);
+
+ Components::One {
+ component,
+ span,
+ ty_inner,
+ }
+ }
+ ref ast_components @ [_, _, ..] => {
+ let components = ast_components
+ .iter()
+ .map(|&expr| self.expression_for_abstract(expr, ctx))
+ .collect::<Result<_, _>>()?;
+ let spans = ast_components
+ .iter()
+ .map(|&expr| ctx.ast_expressions.get_span(expr))
+ .collect();
+
+ for &component in &components {
+ ctx.grow_types(component)?;
+ }
+
+ Components::Many { components, spans }
+ }
+ };
+
+ // Even though we computed `constructor` above, wait until now to borrow
+ // a reference to the `TypeInner`, so that the component-handling code
+ // above can have mutable access to the type arena.
+ let constructor = constructor_h.borrow_inner(ctx.module);
+
+ let expr;
+ match (components, constructor) {
+ // Empty constructor
+ (Components::None, dst_ty) => match dst_ty {
+ Constructor::Type((result_ty, _)) => {
+ return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span)
+ }
+ Constructor::PartialVector { .. }
+ | Constructor::PartialMatrix { .. }
+ | Constructor::PartialArray => {
+ // We have no arguments from which to infer the result type, so
+ // partial constructors aren't acceptable here.
+ return Err(Error::TypeNotInferable(ty_span));
+ }
+ },
+
+ // Scalar constructor & conversion (scalar -> scalar)
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ Constructor::Type((_, &crate::TypeInner::Scalar(scalar))),
+ ) => {
+ expr = crate::Expression::As {
+ expr: component,
+ kind: scalar.kind,
+ convert: Some(scalar.width),
+ };
+ }
+
+ // Vector conversion (vector -> vector)
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
+ ..
+ },
+ Constructor::Type((
+ _,
+ &crate::TypeInner::Vector {
+ size: dst_size,
+ scalar: dst_scalar,
+ },
+ )),
+ ) if dst_size == src_size => {
+ expr = crate::Expression::As {
+ expr: component,
+ kind: dst_scalar.kind,
+ convert: Some(dst_scalar.width),
+ };
+ }
+
+ // Vector conversion (vector -> vector) - partial
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
+ ..
+ },
+ Constructor::PartialVector { size: dst_size },
+ ) if dst_size == src_size => {
+ // This is a trivial conversion: the sizes match, and a Partial
+ // constructor doesn't specify a scalar type, so nothing can
+ // possibly happen.
+ return Ok(component);
+ }
+
+ // Matrix conversion (matrix -> matrix)
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Matrix {
+ columns: src_columns,
+ rows: src_rows,
+ ..
+ },
+ ..
+ },
+ Constructor::Type((
+ _,
+ &crate::TypeInner::Matrix {
+ columns: dst_columns,
+ rows: dst_rows,
+ scalar: dst_scalar,
+ },
+ )),
+ ) if dst_columns == src_columns && dst_rows == src_rows => {
+ expr = crate::Expression::As {
+ expr: component,
+ kind: dst_scalar.kind,
+ convert: Some(dst_scalar.width),
+ };
+ }
+
+ // Matrix conversion (matrix -> matrix) - partial
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Matrix {
+ columns: src_columns,
+ rows: src_rows,
+ ..
+ },
+ ..
+ },
+ Constructor::PartialMatrix {
+ columns: dst_columns,
+ rows: dst_rows,
+ },
+ ) if dst_columns == src_columns && dst_rows == src_rows => {
+ // This is a trivial conversion: the sizes match, and a Partial
+ // constructor doesn't specify a scalar type, so nothing can
+ // possibly happen.
+ return Ok(component);
+ }
+
+ // Vector constructor (splat) - infer type
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ Constructor::PartialVector { size },
+ ) => {
+ expr = crate::Expression::Splat {
+ size,
+ value: component,
+ };
+ }
+
+ // Vector constructor (splat)
+ (
+ Components::One {
+ mut component,
+ ty_inner: &crate::TypeInner::Scalar(_),
+ ..
+ },
+ Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })),
+ ) => {
+ ctx.convert_slice_to_common_leaf_scalar(
+ std::slice::from_mut(&mut component),
+ scalar,
+ )?;
+ expr = crate::Expression::Splat {
+ size,
+ value: component,
+ };
+ }
+
+ // Vector constructor (by elements), partial
+ (
+ Components::Many {
+ mut components,
+ spans,
+ },
+ Constructor::PartialVector { size },
+ ) => {
+ let consensus_scalar =
+ ctx.automatic_conversion_consensus(&components)
+ .map_err(|index| {
+ Error::InvalidConstructorComponentType(spans[index], index as i32)
+ })?;
+ ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
+ let inner = consensus_scalar.to_inner_vector(size);
+ let ty = ctx.ensure_type_exists(inner);
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Vector constructor (by elements), full type given
+ (
+ Components::Many { mut components, .. },
+ Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })),
+ ) => {
+ ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?;
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Matrix constructor (by elements), partial
+ (
+ Components::Many {
+ mut components,
+ spans,
+ },
+ Constructor::PartialMatrix { columns, rows },
+ ) if components.len() == columns as usize * rows as usize => {
+ let consensus_scalar =
+ ctx.automatic_conversion_consensus(&components)
+ .map_err(|index| {
+ Error::InvalidConstructorComponentType(spans[index], index as i32)
+ })?;
+ // We actually only accept floating-point elements.
+ let consensus_scalar = consensus_scalar
+ .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT)
+ .unwrap_or(consensus_scalar);
+ ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
+ let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows));
+
+ let components = components
+ .chunks(rows as usize)
+ .map(|vec_components| {
+ ctx.append_expression(
+ crate::Expression::Compose {
+ ty: vec_ty,
+ components: Vec::from(vec_components),
+ },
+ Default::default(),
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: consensus_scalar,
+ });
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Matrix constructor (by elements), type given
+ (
+ Components::Many { mut components, .. },
+ Constructor::Type((
+ _,
+ &crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ },
+ )),
+ ) if components.len() == columns as usize * rows as usize => {
+ let element = Tr::Value(crate::TypeInner::Scalar(scalar));
+ ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?;
+ let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows));
+
+ let components = components
+ .chunks(rows as usize)
+ .map(|vec_components| {
+ ctx.append_expression(
+ crate::Expression::Compose {
+ ty: vec_ty,
+ components: Vec::from(vec_components),
+ },
+ Default::default(),
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ });
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Matrix constructor (by columns), partial
+ (
+ Components::Many {
+ mut components,
+ spans,
+ },
+ Constructor::PartialMatrix { columns, rows },
+ ) => {
+ let consensus_scalar =
+ ctx.automatic_conversion_consensus(&components)
+ .map_err(|index| {
+ Error::InvalidConstructorComponentType(spans[index], index as i32)
+ })?;
+ ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: consensus_scalar,
+ });
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Matrix constructor (by columns), type given
+ (
+ Components::Many { mut components, .. },
+ Constructor::Type((
+ ty,
+ &crate::TypeInner::Matrix {
+ columns: _,
+ rows,
+ scalar,
+ },
+ )),
+ ) => {
+ let component_ty = crate::TypeInner::Vector { size: rows, scalar };
+ ctx.try_automatic_conversions_slice(
+ &mut components,
+ &Tr::Value(component_ty),
+ ty_span,
+ )?;
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Array constructor - infer type
+ (components, Constructor::PartialArray) => {
+ let mut components = components.into_components_vec();
+ if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(&components) {
+ // Note that this will *not* necessarily convert all the
+ // components to the same type! The `automatic_conversion_consensus`
+ // method only considers the parameters' leaf scalar
+ // types; the parameters themselves could be any mix of
+ // vectors, matrices, and scalars.
+ //
+ // But *if* it is possible for this array construction
+ // expression to be well-typed at all, then all the
+ // parameters must have the same type constructors (vec,
+ // matrix, scalar) applied to their leaf scalars, so
+ // reconciling their scalars is always the right thing to
+ // do. And if this array construction is not well-typed,
+ // these conversions will not make it so, and we can let
+ // validation catch the error.
+ ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
+ } else {
+ // There's no consensus scalar. Emit the `Compose`
+ // expression anyway, and let validation catch the problem.
+ }
+
+ let base = ctx.register_type(components[0])?;
+
+ let inner = crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(
+ NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(),
+ ),
+ stride: {
+ self.layouter.update(ctx.module.to_ctx()).unwrap();
+ self.layouter[base].to_stride()
+ },
+ };
+ let ty = ctx.ensure_type_exists(inner);
+
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Array constructor, explicit type
+ (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => {
+ let mut components = components.into_components_vec();
+ ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?;
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Struct constructor
+ (
+ components,
+ Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })),
+ ) => {
+ let mut components = components.into_components_vec();
+ let struct_ty_span = ctx.module.types.get_span(ty);
+
+ // Make a vector of the members' type handles in advance, to
+ // avoid borrowing `members` from `ctx` while we generate
+ // new code.
+ let members: Vec<Handle<crate::Type>> = members.iter().map(|m| m.ty).collect();
+
+ for (component, &ty) in components.iter_mut().zip(&members) {
+ *component =
+ ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?;
+ }
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // ERRORS
+
+ // Bad conversion (type cast)
+ (Components::One { span, ty_inner, .. }, constructor) => {
+ let from_type = ty_inner.to_wgsl(&ctx.module.to_ctx());
+ return Err(Error::BadTypeCast {
+ span,
+ from_type,
+ to_type: constructor.to_error_string(ctx),
+ });
+ }
+
+ // Too many parameters for scalar constructor
+ (
+ Components::Many { spans, .. },
+ Constructor::Type((_, &crate::TypeInner::Scalar { .. })),
+ ) => {
+ let span = spans[1].until(spans.last().unwrap());
+ return Err(Error::UnexpectedComponents(span));
+ }
+
+ // Other types can't be constructed
+ _ => return Err(Error::TypeNotConstructible(ty_span)),
+ }
+
+ let expr = ctx.append_expression(expr, span)?;
+ Ok(expr)
+ }
+
+ /// Build a [`Constructor`] for a WGSL construction expression.
+ ///
+ /// If `constructor` conveys enough information to determine which Naga [`Type`]
+ /// we're actually building (i.e., it's not a partial constructor), then
+ /// ensure the `Type` exists in [`ctx.module`], and return
+ /// [`Constructor::Type`].
+ ///
+ /// Otherwise, return the [`Constructor`] partial variant corresponding to
+ /// `constructor`.
+ ///
+ /// [`Type`]: crate::Type
+ /// [`ctx.module`]: ExpressionContext::module
+ fn constructor<'out>(
+ &mut self,
+ constructor: &ast::ConstructorType<'source>,
+ ctx: &mut ExpressionContext<'source, '_, 'out>,
+ ) -> Result<Constructor<Handle<crate::Type>>, Error<'source>> {
+ let handle = match *constructor {
+ ast::ConstructorType::Scalar(scalar) => {
+ let ty = ctx.ensure_type_exists(scalar.to_inner_scalar());
+ Constructor::Type(ty)
+ }
+ ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
+ ast::ConstructorType::Vector { size, scalar } => {
+ let ty = ctx.ensure_type_exists(scalar.to_inner_vector(size));
+ Constructor::Type(ty)
+ }
+ ast::ConstructorType::PartialMatrix { columns, rows } => {
+ Constructor::PartialMatrix { columns, rows }
+ }
+ ast::ConstructorType::Matrix {
+ rows,
+ columns,
+ width,
+ } => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: crate::Scalar::float(width),
+ });
+ Constructor::Type(ty)
+ }
+ ast::ConstructorType::PartialArray => Constructor::PartialArray,
+ ast::ConstructorType::Array { base, size } => {
+ let base = self.resolve_ast_type(base, &mut ctx.as_global())?;
+ let size = self.array_size(size, &mut ctx.as_global())?;
+
+ self.layouter.update(ctx.module.to_ctx()).unwrap();
+ let stride = self.layouter[base].to_stride();
+
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride });
+ Constructor::Type(ty)
+ }
+ ast::ConstructorType::Type(ty) => Constructor::Type(ty),
+ };
+
+ Ok(handle)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/conversion.rs b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs
new file mode 100644
index 0000000000..2a2690f096
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs
@@ -0,0 +1,503 @@
+//! WGSL's automatic conversions for abstract types.
+
+use crate::{Handle, Span};
+
+impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> {
+ /// Try to use WGSL's automatic conversions to convert `expr` to `goal_ty`.
+ ///
+ /// If no conversions are necessary, return `expr` unchanged.
+ ///
+ /// If automatic conversions cannot convert `expr` to `goal_ty`, return an
+ /// [`AutoConversion`] error.
+ ///
+ /// Although the Load Rule is one of the automatic conversions, this
+ /// function assumes it has already been applied if appropriate, as
+ /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`.
+ ///
+ /// [`AutoConversion`]: super::Error::AutoConversion
+ pub fn try_automatic_conversions(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ goal_ty: &crate::proc::TypeResolution,
+ goal_span: Span,
+ ) -> Result<Handle<crate::Expression>, super::Error<'source>> {
+ let expr_span = self.get_expression_span(expr);
+ // Keep the TypeResolution so we can get type names for
+ // structs in error messages.
+ let expr_resolution = super::resolve!(self, expr);
+ let types = &self.module.types;
+ let expr_inner = expr_resolution.inner_with(types);
+ let goal_inner = goal_ty.inner_with(types);
+
+ // If `expr` already has the requested type, we're done.
+ if expr_inner.equivalent(goal_inner, types) {
+ return Ok(expr);
+ }
+
+ let (_expr_scalar, goal_scalar) =
+ match expr_inner.automatically_converts_to(goal_inner, types) {
+ Some(scalars) => scalars,
+ None => {
+ let gctx = &self.module.to_ctx();
+ let source_type = expr_resolution.to_wgsl(gctx);
+ let dest_type = goal_ty.to_wgsl(gctx);
+
+ return Err(super::Error::AutoConversion {
+ dest_span: goal_span,
+ dest_type,
+ source_span: expr_span,
+ source_type,
+ });
+ }
+ };
+
+ self.convert_leaf_scalar(expr, expr_span, goal_scalar)
+ }
+
+ /// Try to convert `expr`'s leaf scalar to `goal` using automatic conversions.
+ ///
+ /// If no conversions are necessary, return `expr` unchanged.
+ ///
+ /// If automatic conversions cannot convert `expr` to `goal_scalar`, return
+ /// an [`AutoConversionLeafScalar`] error.
+ ///
+ /// Although the Load Rule is one of the automatic conversions, this
+ /// function assumes it has already been applied if appropriate, as
+ /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`.
+ ///
+ /// [`AutoConversionLeafScalar`]: super::Error::AutoConversionLeafScalar
+ pub fn try_automatic_conversion_for_leaf_scalar(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ goal_scalar: crate::Scalar,
+ goal_span: Span,
+ ) -> Result<Handle<crate::Expression>, super::Error<'source>> {
+ let expr_span = self.get_expression_span(expr);
+ let expr_resolution = super::resolve!(self, expr);
+ let types = &self.module.types;
+ let expr_inner = expr_resolution.inner_with(types);
+
+ let make_error = || {
+ let gctx = &self.module.to_ctx();
+ let source_type = expr_resolution.to_wgsl(gctx);
+ super::Error::AutoConversionLeafScalar {
+ dest_span: goal_span,
+ dest_scalar: goal_scalar.to_wgsl(),
+ source_span: expr_span,
+ source_type,
+ }
+ };
+
+ let expr_scalar = match expr_inner.scalar() {
+ Some(scalar) => scalar,
+ None => return Err(make_error()),
+ };
+
+ if expr_scalar == goal_scalar {
+ return Ok(expr);
+ }
+
+ if !expr_scalar.automatically_converts_to(goal_scalar) {
+ return Err(make_error());
+ }
+
+ assert!(expr_scalar.is_abstract());
+
+ self.convert_leaf_scalar(expr, expr_span, goal_scalar)
+ }
+
+ fn convert_leaf_scalar(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ expr_span: Span,
+ goal_scalar: crate::Scalar,
+ ) -> Result<Handle<crate::Expression>, super::Error<'source>> {
+ let expr_inner = super::resolve_inner!(self, expr);
+ if let crate::TypeInner::Array { .. } = *expr_inner {
+ self.as_const_evaluator()
+ .cast_array(expr, goal_scalar, expr_span)
+ .map_err(|err| super::Error::ConstantEvaluatorError(err, expr_span))
+ } else {
+ let cast = crate::Expression::As {
+ expr,
+ kind: goal_scalar.kind,
+ convert: Some(goal_scalar.width),
+ };
+ self.append_expression(cast, expr_span)
+ }
+ }
+
+ /// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions.
+ pub fn try_automatic_conversions_slice(
+ &mut self,
+ exprs: &mut [Handle<crate::Expression>],
+ goal_ty: &crate::proc::TypeResolution,
+ goal_span: Span,
+ ) -> Result<(), super::Error<'source>> {
+ for expr in exprs.iter_mut() {
+ *expr = self.try_automatic_conversions(*expr, goal_ty, goal_span)?;
+ }
+
+ Ok(())
+ }
+
+ /// Apply WGSL's automatic conversions to a vector constructor's arguments.
+ ///
+ /// When calling a vector constructor like `vec3<f32>(...)`, the parameters
+ /// can be a mix of scalars and vectors, with the latter being spread out to
+ /// contribute each of their components as a component of the new value.
+ /// When the element type is explicit, as with `<f32>` in the example above,
+ /// WGSL's automatic conversions should convert abstract scalar and vector
+ /// parameters to the constructor's required scalar type.
+ pub fn try_automatic_conversions_for_vector(
+ &mut self,
+ exprs: &mut [Handle<crate::Expression>],
+ goal_scalar: crate::Scalar,
+ goal_span: Span,
+ ) -> Result<(), super::Error<'source>> {
+ use crate::proc::TypeResolution as Tr;
+ use crate::TypeInner as Ti;
+ let goal_scalar_res = Tr::Value(Ti::Scalar(goal_scalar));
+
+ for (i, expr) in exprs.iter_mut().enumerate() {
+ // Keep the TypeResolution so we can get full type names
+ // in error messages.
+ let expr_resolution = super::resolve!(self, *expr);
+ let types = &self.module.types;
+ let expr_inner = expr_resolution.inner_with(types);
+
+ match *expr_inner {
+ Ti::Scalar(_) => {
+ *expr = self.try_automatic_conversions(*expr, &goal_scalar_res, goal_span)?;
+ }
+ Ti::Vector { size, scalar: _ } => {
+ let goal_vector_res = Tr::Value(Ti::Vector {
+ size,
+ scalar: goal_scalar,
+ });
+ *expr = self.try_automatic_conversions(*expr, &goal_vector_res, goal_span)?;
+ }
+ _ => {
+ let span = self.get_expression_span(*expr);
+ return Err(super::Error::InvalidConstructorComponentType(
+ span, i as i32,
+ ));
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Convert `expr` to the leaf scalar type `scalar`.
+ pub fn convert_to_leaf_scalar(
+ &mut self,
+ expr: &mut Handle<crate::Expression>,
+ goal: crate::Scalar,
+ ) -> Result<(), super::Error<'source>> {
+ let inner = super::resolve_inner!(self, *expr);
+ // Do nothing if `inner` doesn't even have leaf scalars;
+ // it's a type error that validation will catch.
+ if inner.scalar() != Some(goal) {
+ let cast = crate::Expression::As {
+ expr: *expr,
+ kind: goal.kind,
+ convert: Some(goal.width),
+ };
+ let expr_span = self.get_expression_span(*expr);
+ *expr = self.append_expression(cast, expr_span)?;
+ }
+
+ Ok(())
+ }
+
+ /// Convert all expressions in `exprs` to a common scalar type.
+ ///
+ /// Note that the caller is responsible for making sure these
+ /// conversions are actually justified. This function simply
+ /// generates `As` expressions, regardless of whether they are
+ /// permitted WGSL automatic conversions. Callers intending to
+ /// implement automatic conversions need to determine for
+ /// themselves whether the casts we we generate are justified,
+ /// perhaps by calling `TypeInner::automatically_converts_to` or
+ /// `Scalar::automatic_conversion_combine`.
+ pub fn convert_slice_to_common_leaf_scalar(
+ &mut self,
+ exprs: &mut [Handle<crate::Expression>],
+ goal: crate::Scalar,
+ ) -> Result<(), super::Error<'source>> {
+ for expr in exprs.iter_mut() {
+ self.convert_to_leaf_scalar(expr, goal)?;
+ }
+
+ Ok(())
+ }
+
+ /// Return an expression for the concretized value of `expr`.
+ ///
+ /// If `expr` is already concrete, return it unchanged.
+ pub fn concretize(
+ &mut self,
+ mut expr: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Expression>, super::Error<'source>> {
+ let inner = super::resolve_inner!(self, expr);
+ if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) {
+ let concretized = scalar.concretize();
+ if concretized != scalar {
+ assert!(scalar.is_abstract());
+ let expr_span = self.get_expression_span(expr);
+ expr = self
+ .as_const_evaluator()
+ .cast_array(expr, concretized, expr_span)
+ .map_err(|err| {
+ // A `TypeResolution` includes the type's full name, if
+ // it has one. Also, avoid holding the borrow of `inner`
+ // across the call to `cast_array`.
+ let expr_type = &self.typifier()[expr];
+ super::Error::ConcretizationFailed {
+ expr_span,
+ expr_type: expr_type.to_wgsl(&self.module.to_ctx()),
+ scalar: concretized.to_wgsl(),
+ inner: err,
+ }
+ })?;
+ }
+ }
+
+ Ok(expr)
+ }
+
+ /// Find the consensus scalar of `components` under WGSL's automatic
+ /// conversions.
+ ///
+ /// If `components` can all be converted to any common scalar via
+ /// WGSL's automatic conversions, return the best such scalar.
+ ///
+ /// The `components` slice must not be empty. All elements' types must
+ /// have been resolved.
+ ///
+ /// If `components` are definitely not acceptable as arguments to such
+ /// constructors, return `Err(i)`, where `i` is the index in
+ /// `components` of some problematic argument.
+ ///
+ /// This function doesn't fully type-check the arguments - it only
+ /// considers their leaf scalar types. This means it may return `Ok`
+ /// even when the Naga validator will reject the resulting
+ /// construction expression later.
+ pub fn automatic_conversion_consensus<'handle, I>(
+ &self,
+ components: I,
+ ) -> Result<crate::Scalar, usize>
+ where
+ I: IntoIterator<Item = &'handle Handle<crate::Expression>>,
+ I::IntoIter: Clone, // for debugging
+ {
+ let types = &self.module.types;
+ let mut inners = components
+ .into_iter()
+ .map(|&c| self.typifier()[c].inner_with(types));
+ log::debug!(
+ "wgsl automatic_conversion_consensus: {:?}",
+ inners
+ .clone()
+ .map(|inner| inner.to_wgsl(&self.module.to_ctx()))
+ .collect::<Vec<String>>()
+ );
+ let mut best = inners.next().unwrap().scalar().ok_or(0_usize)?;
+ for (inner, i) in inners.zip(1..) {
+ let scalar = inner.scalar().ok_or(i)?;
+ match best.automatic_conversion_combine(scalar) {
+ Some(new_best) => {
+ best = new_best;
+ }
+ None => return Err(i),
+ }
+ }
+
+ log::debug!(" consensus: {:?}", best.to_wgsl());
+ Ok(best)
+ }
+}
+
+impl crate::TypeInner {
+ /// Determine whether `self` automatically converts to `goal`.
+ ///
+ /// If WGSL's automatic conversions (excluding the Load Rule) will
+ /// convert `self` to `goal`, then return a pair `(from, to)`,
+ /// where `from` and `to` are the scalar types of the leaf values
+ /// of `self` and `goal`.
+ ///
+ /// This function assumes that `self` and `goal` are different
+ /// types. Callers should first check whether any conversion is
+ /// needed at all.
+ ///
+ /// If the automatic conversions cannot convert `self` to `goal`,
+ /// return `None`.
+ fn automatically_converts_to(
+ &self,
+ goal: &Self,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> Option<(crate::Scalar, crate::Scalar)> {
+ use crate::ScalarKind as Sk;
+ use crate::TypeInner as Ti;
+
+ // Automatic conversions only change the scalar type of a value's leaves
+ // (e.g., `vec4<AbstractFloat>` to `vec4<f32>`), never the type
+ // constructors applied to those scalar types (e.g., never scalar to
+ // `vec4`, or `vec2` to `vec3`). So first we check that the type
+ // constructors match, extracting the leaf scalar types in the process.
+ let expr_scalar;
+ let goal_scalar;
+ match (self, goal) {
+ (&Ti::Scalar(expr), &Ti::Scalar(goal)) => {
+ expr_scalar = expr;
+ goal_scalar = goal;
+ }
+ (
+ &Ti::Vector {
+ size: expr_size,
+ scalar: expr,
+ },
+ &Ti::Vector {
+ size: goal_size,
+ scalar: goal,
+ },
+ ) if expr_size == goal_size => {
+ expr_scalar = expr;
+ goal_scalar = goal;
+ }
+ (
+ &Ti::Matrix {
+ rows: expr_rows,
+ columns: expr_columns,
+ scalar: expr,
+ },
+ &Ti::Matrix {
+ rows: goal_rows,
+ columns: goal_columns,
+ scalar: goal,
+ },
+ ) if expr_rows == goal_rows && expr_columns == goal_columns => {
+ expr_scalar = expr;
+ goal_scalar = goal;
+ }
+ (
+ &Ti::Array {
+ base: expr_base,
+ size: expr_size,
+ stride: _,
+ },
+ &Ti::Array {
+ base: goal_base,
+ size: goal_size,
+ stride: _,
+ },
+ ) if expr_size == goal_size => {
+ return types[expr_base]
+ .inner
+ .automatically_converts_to(&types[goal_base].inner, types);
+ }
+ _ => return None,
+ }
+
+ match (expr_scalar.kind, goal_scalar.kind) {
+ (Sk::AbstractFloat, Sk::Float) => {}
+ (Sk::AbstractInt, Sk::Sint | Sk::Uint | Sk::AbstractFloat | Sk::Float) => {}
+ _ => return None,
+ }
+
+ log::trace!(" okay: expr {expr_scalar:?}, goal {goal_scalar:?}");
+ Some((expr_scalar, goal_scalar))
+ }
+
+ fn automatically_convertible_scalar(
+ &self,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> Option<crate::Scalar> {
+ use crate::TypeInner as Ti;
+ match *self {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => {
+ Some(scalar)
+ }
+ Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types),
+ Ti::Atomic(_)
+ | Ti::Pointer { .. }
+ | Ti::ValuePointer { .. }
+ | Ti::Struct { .. }
+ | Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery
+ | Ti::BindingArray { .. } => None,
+ }
+ }
+}
+
+impl crate::Scalar {
+ /// Find the common type of `self` and `other` under WGSL's
+ /// automatic conversions.
+ ///
+ /// If there are any scalars to which WGSL's automatic conversions
+ /// will convert both `self` and `other`, return the best such
+ /// scalar. Otherwise, return `None`.
+ pub const fn automatic_conversion_combine(self, other: Self) -> Option<crate::Scalar> {
+ use crate::ScalarKind as Sk;
+
+ match (self.kind, other.kind) {
+ // When the kinds match...
+ (Sk::AbstractFloat, Sk::AbstractFloat)
+ | (Sk::AbstractInt, Sk::AbstractInt)
+ | (Sk::Sint, Sk::Sint)
+ | (Sk::Uint, Sk::Uint)
+ | (Sk::Float, Sk::Float)
+ | (Sk::Bool, Sk::Bool) => {
+ if self.width == other.width {
+ // ... either no conversion is necessary ...
+ Some(self)
+ } else {
+ // ... or no conversion is possible.
+ // We never convert concrete to concrete, and
+ // abstract types should have only one size.
+ None
+ }
+ }
+
+ // AbstractInt converts to AbstractFloat.
+ (Sk::AbstractFloat, Sk::AbstractInt) => Some(self),
+ (Sk::AbstractInt, Sk::AbstractFloat) => Some(other),
+
+ // AbstractFloat converts to Float.
+ (Sk::AbstractFloat, Sk::Float) => Some(other),
+ (Sk::Float, Sk::AbstractFloat) => Some(self),
+
+ // AbstractInt converts to concrete integer or float.
+ (Sk::AbstractInt, Sk::Uint | Sk::Sint | Sk::Float) => Some(other),
+ (Sk::Uint | Sk::Sint | Sk::Float, Sk::AbstractInt) => Some(self),
+
+ // AbstractFloat can't be reconciled with concrete integer types.
+ (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => {
+ None
+ }
+
+ // Nothing can be reconciled with `bool`.
+ (Sk::Bool, _) | (_, Sk::Bool) => None,
+
+ // Different concrete types cannot be reconciled.
+ (Sk::Sint | Sk::Uint | Sk::Float, Sk::Sint | Sk::Uint | Sk::Float) => None,
+ }
+ }
+
+ /// Return `true` if automatic conversions will covert `self` to `goal`.
+ pub fn automatically_converts_to(self, goal: Self) -> bool {
+ self.automatic_conversion_combine(goal) == Some(goal)
+ }
+
+ const fn concretize(self) -> Self {
+ use crate::ScalarKind as Sk;
+ match self.kind {
+ Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self,
+ Sk::AbstractInt => Self::I32,
+ Sk::AbstractFloat => Self::F32,
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
new file mode 100644
index 0000000000..ba9b49e135
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
@@ -0,0 +1,2760 @@
+use std::num::NonZeroU32;
+
+use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType};
+use crate::front::wgsl::index::Index;
+use crate::front::wgsl::parse::number::Number;
+use crate::front::wgsl::parse::{ast, conv};
+use crate::front::Typifier;
+use crate::proc::{
+ ensure_block_returns, Alignment, ConstantEvaluator, Emitter, Layouter, ResolveContext,
+};
+use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span};
+
+mod construction;
+mod conversion;
+
+/// Resolves the inner type of a given expression.
+///
+/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`].
+///
+/// Returns a &[`crate::TypeInner`].
+///
+/// Ideally, we would simply have a function that takes a `&mut ExpressionContext`
+/// and returns a `&TypeResolution`. Unfortunately, this leads the borrow checker
+/// to conclude that the mutable borrow lasts for as long as we are using the
+/// `&TypeResolution`, so we can't use the `ExpressionContext` for anything else -
+/// like, say, resolving another operand's type. Using a macro that expands to
+/// two separate calls, only the first of which needs a `&mut`,
+/// lets the borrow checker see that the mutable borrow is over.
+macro_rules! resolve_inner {
+ ($ctx:ident, $expr:expr) => {{
+ $ctx.grow_types($expr)?;
+ $ctx.typifier()[$expr].inner_with(&$ctx.module.types)
+ }};
+}
+pub(super) use resolve_inner;
+
+/// Resolves the inner types of two given expressions.
+///
+/// Expects a &mut [`ExpressionContext`] and two [`Handle<Expression>`]s.
+///
+/// Returns a tuple containing two &[`crate::TypeInner`].
+///
+/// See the documentation of [`resolve_inner!`] for why this macro is necessary.
+macro_rules! resolve_inner_binary {
+ ($ctx:ident, $left:expr, $right:expr) => {{
+ $ctx.grow_types($left)?;
+ $ctx.grow_types($right)?;
+ (
+ $ctx.typifier()[$left].inner_with(&$ctx.module.types),
+ $ctx.typifier()[$right].inner_with(&$ctx.module.types),
+ )
+ }};
+}
+
+/// Resolves the type of a given expression.
+///
+/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`].
+///
+/// Returns a &[`TypeResolution`].
+///
+/// See the documentation of [`resolve_inner!`] for why this macro is necessary.
+///
+/// [`TypeResolution`]: crate::proc::TypeResolution
+macro_rules! resolve {
+ ($ctx:ident, $expr:expr) => {{
+ $ctx.grow_types($expr)?;
+ &$ctx.typifier()[$expr]
+ }};
+}
+pub(super) use resolve;
+
+/// State for constructing a `crate::Module`.
+pub struct GlobalContext<'source, 'temp, 'out> {
+ /// The `TranslationUnit`'s expressions arena.
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+
+ /// The `TranslationUnit`'s types arena.
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// The module we're constructing.
+ module: &'out mut crate::Module,
+
+ const_typifier: &'temp mut Typifier,
+}
+
+impl<'source> GlobalContext<'source, '_, '_> {
+ fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> {
+ ExpressionContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ expr_type: ExpressionContextType::Constant,
+ }
+ }
+
+ fn ensure_type_exists(
+ &mut self,
+ name: Option<String>,
+ inner: crate::TypeInner,
+ ) -> Handle<crate::Type> {
+ self.module
+ .types
+ .insert(crate::Type { inner, name }, Span::UNDEFINED)
+ }
+}
+
+/// State for lowering a statement within a function.
+pub struct StatementContext<'source, 'temp, 'out> {
+ // WGSL AST values.
+ /// A reference to [`TranslationUnit::expressions`] for the translation unit
+ /// we're lowering.
+ ///
+ /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+
+ /// A reference to [`TranslationUnit::types`] for the translation unit
+ /// we're lowering.
+ ///
+ /// [`TranslationUnit::types`]: ast::TranslationUnit::types
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// A map from each `ast::Local` handle to the Naga expression
+ /// we've built for it:
+ ///
+ /// - WGSL function arguments become Naga [`FunctionArgument`] expressions.
+ ///
+ /// - WGSL `var` declarations become Naga [`LocalVariable`] expressions.
+ ///
+ /// - WGSL `let` declararations become arbitrary Naga expressions.
+ ///
+ /// This always borrows the `local_table` local variable in
+ /// [`Lowerer::function`].
+ ///
+ /// [`LocalVariable`]: crate::Expression::LocalVariable
+ /// [`FunctionArgument`]: crate::Expression::FunctionArgument
+ local_table: &'temp mut FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>,
+
+ const_typifier: &'temp mut Typifier,
+ typifier: &'temp mut Typifier,
+ function: &'out mut crate::Function,
+ /// Stores the names of expressions that are assigned in `let` statement
+ /// Also stores the spans of the names, for use in errors.
+ named_expressions: &'out mut FastIndexMap<Handle<crate::Expression>, (String, Span)>,
+ module: &'out mut crate::Module,
+
+ /// Which `Expression`s in `self.naga_expressions` are const expressions, in
+ /// the WGSL sense.
+ ///
+ /// According to the WGSL spec, a const expression must not refer to any
+ /// `let` declarations, even if those declarations' initializers are
+ /// themselves const expressions. So this tracker is not simply concerned
+ /// with the form of the expressions; it is also tracking whether WGSL says
+ /// we should consider them to be const. See the use of `force_non_const` in
+ /// the code for lowering `let` bindings.
+ expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+}
+
+impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
+ fn as_expression<'t>(
+ &'t mut self,
+ block: &'t mut crate::Block,
+ emitter: &'t mut Emitter,
+ ) -> ExpressionContext<'a, 't, '_>
+ where
+ 'temp: 't,
+ {
+ ExpressionContext {
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ const_typifier: self.const_typifier,
+ module: self.module,
+ expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext {
+ local_table: self.local_table,
+ function: self.function,
+ block,
+ emitter,
+ typifier: self.typifier,
+ expression_constness: self.expression_constness,
+ }),
+ }
+ }
+
+ fn as_global(&mut self) -> GlobalContext<'a, '_, '_> {
+ GlobalContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ }
+ }
+
+ fn invalid_assignment_type(&self, expr: Handle<crate::Expression>) -> InvalidAssignmentType {
+ if let Some(&(_, span)) = self.named_expressions.get(&expr) {
+ InvalidAssignmentType::ImmutableBinding(span)
+ } else {
+ match self.function.expressions[expr] {
+ crate::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle,
+ crate::Expression::Access { base, .. } => self.invalid_assignment_type(base),
+ crate::Expression::AccessIndex { base, .. } => self.invalid_assignment_type(base),
+ _ => InvalidAssignmentType::Other,
+ }
+ }
+ }
+}
+
+pub struct RuntimeExpressionContext<'temp, 'out> {
+ /// A map from [`ast::Local`] handles to the Naga expressions we've built for them.
+ ///
+ /// This is always [`StatementContext::local_table`] for the
+ /// enclosing statement; see that documentation for details.
+ local_table: &'temp FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>,
+
+ function: &'out mut crate::Function,
+ block: &'temp mut crate::Block,
+ emitter: &'temp mut Emitter,
+ typifier: &'temp mut Typifier,
+
+ /// Which `Expression`s in `self.naga_expressions` are const expressions, in
+ /// the WGSL sense.
+ ///
+ /// See [`StatementContext::expression_constness`] for details.
+ expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+}
+
+/// The type of Naga IR expression we are lowering an [`ast::Expression`] to.
+pub enum ExpressionContextType<'temp, 'out> {
+ /// We are lowering to an arbitrary runtime expression, to be
+ /// included in a function's body.
+ ///
+ /// The given [`RuntimeExpressionContext`] holds information about local
+ /// variables, arguments, and other definitions available only to runtime
+ /// expressions, not constant or override expressions.
+ Runtime(RuntimeExpressionContext<'temp, 'out>),
+
+ /// We are lowering to a constant expression, to be included in the module's
+ /// constant expression arena.
+ ///
+ /// Everything constant expressions are allowed to refer to is
+ /// available in the [`ExpressionContext`], so this variant
+ /// carries no further information.
+ Constant,
+}
+
+/// State for lowering an [`ast::Expression`] to Naga IR.
+///
+/// [`ExpressionContext`]s come in two kinds, distinguished by
+/// the value of the [`expr_type`] field:
+///
+/// - A [`Runtime`] context contributes [`naga::Expression`]s to a [`naga::Function`]'s
+/// runtime expression arena.
+///
+/// - A [`Constant`] context contributes [`naga::Expression`]s to a [`naga::Module`]'s
+/// constant expression arena.
+///
+/// [`ExpressionContext`]s are constructed in restricted ways:
+///
+/// - To get a [`Runtime`] [`ExpressionContext`], call
+/// [`StatementContext::as_expression`].
+///
+/// - To get a [`Constant`] [`ExpressionContext`], call
+/// [`GlobalContext::as_const`].
+///
+/// - You can demote a [`Runtime`] context to a [`Constant`] context
+/// by calling [`as_const`], but there's no way to go in the other
+/// direction, producing a runtime context from a constant one. This
+/// is because runtime expressions can refer to constant
+/// expressions, via [`Expression::Constant`], but constant
+/// expressions can't refer to a function's expressions.
+///
+/// Not to be confused with `wgsl::parse::ExpressionContext`, which is
+/// for parsing the `ast::Expression` in the first place.
+///
+/// [`expr_type`]: ExpressionContext::expr_type
+/// [`Runtime`]: ExpressionContextType::Runtime
+/// [`naga::Expression`]: crate::Expression
+/// [`naga::Function`]: crate::Function
+/// [`Constant`]: ExpressionContextType::Constant
+/// [`naga::Module`]: crate::Module
+/// [`as_const`]: ExpressionContext::as_const
+/// [`Expression::Constant`]: crate::Expression::Constant
+pub struct ExpressionContext<'source, 'temp, 'out> {
+ // WGSL AST values.
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// The IR [`Module`] we're constructing.
+ ///
+ /// [`Module`]: crate::Module
+ module: &'out mut crate::Module,
+
+ /// Type judgments for [`module::const_expressions`].
+ ///
+ /// [`module::const_expressions`]: crate::Module::const_expressions
+ const_typifier: &'temp mut Typifier,
+
+ /// Whether we are lowering a constant expression or a general
+ /// runtime expression, and the data needed in each case.
+ expr_type: ExpressionContextType<'temp, 'out>,
+}
+
+impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
+ fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> {
+ ExpressionContext {
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ const_typifier: self.const_typifier,
+ module: self.module,
+ expr_type: ExpressionContextType::Constant,
+ }
+ }
+
+ fn as_global(&mut self) -> GlobalContext<'source, '_, '_> {
+ GlobalContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ }
+ }
+
+ fn as_const_evaluator(&mut self) -> ConstantEvaluator {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function(
+ self.module,
+ &mut rctx.function.expressions,
+ rctx.expression_constness,
+ rctx.emitter,
+ rctx.block,
+ ),
+ ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module),
+ }
+ }
+
+ fn append_expression(
+ &mut self,
+ expr: crate::Expression,
+ span: Span,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut eval = self.as_const_evaluator();
+ match eval.try_eval_and_append(&expr, span) {
+ Ok(expr) => Ok(expr),
+
+ // `expr` is not a constant expression. This is fine as
+ // long as we're not building `Module::const_expressions`.
+ Err(err) => match self.expr_type {
+ ExpressionContextType::Runtime(ref mut rctx) => {
+ Ok(rctx.function.expressions.append(expr, span))
+ }
+ ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)),
+ },
+ }
+ }
+
+ fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref ctx) => {
+ if !ctx.expression_constness.is_const(handle) {
+ return None;
+ }
+
+ self.module
+ .to_ctx()
+ .eval_expr_to_u32_from(handle, &ctx.function.expressions)
+ .ok()
+ }
+ ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
+ }
+ }
+
+ fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle),
+ ExpressionContextType::Constant => self.module.const_expressions.get_span(handle),
+ }
+ }
+
+ fn typifier(&self) -> &Typifier {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
+ ExpressionContextType::Constant => self.const_typifier,
+ }
+ }
+
+ fn runtime_expression_ctx(
+ &mut self,
+ span: Span,
+ ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
+ ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)),
+ }
+ }
+
+ fn gather_component(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ component_span: Span,
+ gather_span: Span,
+ ) -> Result<crate::SwizzleComponent, Error<'source>> {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref rctx) => {
+ if !rctx.expression_constness.is_const(expr) {
+ return Err(Error::ExpectedConstExprConcreteIntegerScalar(
+ component_span,
+ ));
+ }
+
+ let index = self
+ .module
+ .to_ctx()
+ .eval_expr_to_u32_from(expr, &rctx.function.expressions)
+ .map_err(|err| match err {
+ crate::proc::U32EvalError::NonConst => {
+ Error::ExpectedConstExprConcreteIntegerScalar(component_span)
+ }
+ crate::proc::U32EvalError::Negative => {
+ Error::ExpectedNonNegative(component_span)
+ }
+ })?;
+ crate::SwizzleComponent::XYZW
+ .get(index as usize)
+ .copied()
+ .ok_or(Error::InvalidGatherComponent(component_span))
+ }
+ // This means a `gather` operation appeared in a constant expression.
+ // This error refers to the `gather` itself, not its "component" argument.
+ ExpressionContextType::Constant => {
+ Err(Error::UnexpectedOperationInConstContext(gather_span))
+ }
+ }
+ }
+
+ /// Determine the type of `handle`, and add it to the module's arena.
+ ///
+ /// If you just need a `TypeInner` for `handle`'s type, use the
+ /// [`resolve_inner!`] macro instead. This function
+ /// should only be used when the type of `handle` needs to appear
+ /// in the module's final `Arena<Type>`, for example, if you're
+ /// creating a [`LocalVariable`] whose type is inferred from its
+ /// initializer.
+ ///
+ /// [`LocalVariable`]: crate::LocalVariable
+ fn register_type(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ self.grow_types(handle)?;
+ // This is equivalent to calling ExpressionContext::typifier(),
+ // except that this lets the borrow checker see that it's okay
+ // to also borrow self.module.types mutably below.
+ let typifier = match self.expr_type {
+ ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
+ ExpressionContextType::Constant => &*self.const_typifier,
+ };
+ Ok(typifier.register_type(handle, &mut self.module.types))
+ }
+
+ /// Resolve the types of all expressions up through `handle`.
+ ///
+ /// Ensure that [`self.typifier`] has a [`TypeResolution`] for
+ /// every expression in [`self.function.expressions`].
+ ///
+ /// This does not add types to any arena. The [`Typifier`]
+ /// documentation explains the steps we take to avoid filling
+ /// arenas with intermediate types.
+ ///
+ /// This function takes `&mut self`, so it can't conveniently
+ /// return a shared reference to the resulting `TypeResolution`:
+ /// the shared reference would extend the mutable borrow, and you
+ /// wouldn't be able to use `self` for anything else. Instead, you
+ /// should use [`register_type`] or one of [`resolve!`],
+ /// [`resolve_inner!`] or [`resolve_inner_binary!`].
+ ///
+ /// [`self.typifier`]: ExpressionContext::typifier
+ /// [`TypeResolution`]: crate::proc::TypeResolution
+ /// [`register_type`]: Self::register_type
+ /// [`Typifier`]: Typifier
+ fn grow_types(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<&mut Self, Error<'source>> {
+ let empty_arena = Arena::new();
+ let resolve_ctx;
+ let typifier;
+ let expressions;
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut ctx) => {
+ resolve_ctx = ResolveContext::with_locals(
+ self.module,
+ &ctx.function.local_variables,
+ &ctx.function.arguments,
+ );
+ typifier = &mut *ctx.typifier;
+ expressions = &ctx.function.expressions;
+ }
+ ExpressionContextType::Constant => {
+ resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]);
+ typifier = self.const_typifier;
+ expressions = &self.module.const_expressions;
+ }
+ };
+ typifier
+ .grow(handle, expressions, &resolve_ctx)
+ .map_err(Error::InvalidResolve)?;
+
+ Ok(self)
+ }
+
+ fn image_data(
+ &mut self,
+ image: Handle<crate::Expression>,
+ span: Span,
+ ) -> Result<(crate::ImageClass, bool), Error<'source>> {
+ match *resolve_inner!(self, image) {
+ crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)),
+ _ => Err(Error::BadTexture(span)),
+ }
+ }
+
+ fn prepare_args<'b>(
+ &mut self,
+ args: &'b [Handle<ast::Expression<'source>>],
+ min_args: u32,
+ span: Span,
+ ) -> ArgumentContext<'b, 'source> {
+ ArgumentContext {
+ args: args.iter(),
+ min_args,
+ args_used: 0,
+ total_args: args.len() as u32,
+ span,
+ }
+ }
+
+ /// Insert splats, if needed by the non-'*' operations.
+ ///
+ /// See the "Binary arithmetic expressions with mixed scalar and vector operands"
+ /// table in the WebGPU Shading Language specification for relevant operators.
+ ///
+ /// Multiply is not handled here as backends are expected to handle vec*scalar
+ /// operations, so inserting splats into the IR increases size needlessly.
+ fn binary_op_splat(
+ &mut self,
+ op: crate::BinaryOperator,
+ left: &mut Handle<crate::Expression>,
+ right: &mut Handle<crate::Expression>,
+ ) -> Result<(), Error<'source>> {
+ if matches!(
+ op,
+ crate::BinaryOperator::Add
+ | crate::BinaryOperator::Subtract
+ | crate::BinaryOperator::Divide
+ | crate::BinaryOperator::Modulo
+ ) {
+ match resolve_inner_binary!(self, *left, *right) {
+ (&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => {
+ *right = self.append_expression(
+ crate::Expression::Splat {
+ size,
+ value: *right,
+ },
+ self.get_expression_span(*right),
+ )?;
+ }
+ (&crate::TypeInner::Scalar { .. }, &crate::TypeInner::Vector { size, .. }) => {
+ *left = self.append_expression(
+ crate::Expression::Splat { size, value: *left },
+ self.get_expression_span(*left),
+ )?;
+ }
+ _ => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Add a single expression to the expression table that is not covered by `self.emitter`.
+ ///
+ /// This is useful for `CallResult` and `AtomicResult` expressions, which should not be covered by
+ /// `Emit` statements.
+ fn interrupt_emitter(
+ &mut self,
+ expression: crate::Expression,
+ span: Span,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut rctx) => {
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ }
+ ExpressionContextType::Constant => {}
+ }
+ let result = self.append_expression(expression, span);
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut rctx) => {
+ rctx.emitter.start(&rctx.function.expressions);
+ }
+ ExpressionContextType::Constant => {}
+ }
+ result
+ }
+
+ /// Apply the WGSL Load Rule to `expr`.
+ ///
+ /// If `expr` is has type `ref<SC, T, A>`, perform a load to produce a value of type
+ /// `T`. Otherwise, return `expr` unchanged.
+ fn apply_load_rule(
+ &mut self,
+ expr: Typed<Handle<crate::Expression>>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ match expr {
+ Typed::Reference(pointer) => {
+ let load = crate::Expression::Load { pointer };
+ let span = self.get_expression_span(pointer);
+ self.append_expression(load, span)
+ }
+ Typed::Plain(handle) => Ok(handle),
+ }
+ }
+
+ fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle<crate::Type> {
+ self.as_global().ensure_type_exists(None, inner)
+ }
+}
+
+struct ArgumentContext<'ctx, 'source> {
+ args: std::slice::Iter<'ctx, Handle<ast::Expression<'source>>>,
+ min_args: u32,
+ args_used: u32,
+ total_args: u32,
+ span: Span,
+}
+
+impl<'source> ArgumentContext<'_, 'source> {
+ pub fn finish(self) -> Result<(), Error<'source>> {
+ if self.args.len() == 0 {
+ Ok(())
+ } else {
+ Err(Error::WrongArgumentCount {
+ found: self.total_args,
+ expected: self.min_args..self.args_used + 1,
+ span: self.span,
+ })
+ }
+ }
+
+ pub fn next(&mut self) -> Result<Handle<ast::Expression<'source>>, Error<'source>> {
+ match self.args.next().copied() {
+ Some(arg) => {
+ self.args_used += 1;
+ Ok(arg)
+ }
+ None => Err(Error::WrongArgumentCount {
+ found: self.total_args,
+ expected: self.min_args..self.args_used + 1,
+ span: self.span,
+ }),
+ }
+ }
+}
+
+/// WGSL type annotations on expressions, types, values, etc.
+///
+/// Naga and WGSL types are very close, but Naga lacks WGSL's `ref` types, which
+/// we need to know to apply the Load Rule. This enum carries some WGSL or Naga
+/// datum along with enough information to determine its corresponding WGSL
+/// type.
+///
+/// The `T` type parameter can be any expression-like thing:
+///
+/// - `Typed<Handle<crate::Type>>` can represent a full WGSL type. For example,
+/// given some Naga `Pointer` type `ptr`, a WGSL reference type is a
+/// `Typed::Reference(ptr)` whereas a WGSL pointer type is a
+/// `Typed::Plain(ptr)`.
+///
+/// - `Typed<crate::Expression>` or `Typed<Handle<crate::Expression>>` can
+/// represent references similarly.
+///
+/// Use the `map` and `try_map` methods to convert from one expression
+/// representation to another.
+///
+/// [`Expression`]: crate::Expression
+#[derive(Debug, Copy, Clone)]
+enum Typed<T> {
+ /// A WGSL reference.
+ Reference(T),
+
+ /// A WGSL plain type.
+ Plain(T),
+}
+
+impl<T> Typed<T> {
+ fn map<U>(self, mut f: impl FnMut(T) -> U) -> Typed<U> {
+ match self {
+ Self::Reference(v) => Typed::Reference(f(v)),
+ Self::Plain(v) => Typed::Plain(f(v)),
+ }
+ }
+
+ fn try_map<U, E>(self, mut f: impl FnMut(T) -> Result<U, E>) -> Result<Typed<U>, E> {
+ Ok(match self {
+ Self::Reference(expr) => Typed::Reference(f(expr)?),
+ Self::Plain(expr) => Typed::Plain(f(expr)?),
+ })
+ }
+}
+
+/// A single vector component or swizzle.
+///
+/// This represents the things that can appear after the `.` in a vector access
+/// expression: either a single component name, or a series of them,
+/// representing a swizzle.
+enum Components {
+ Single(u32),
+ Swizzle {
+ size: crate::VectorSize,
+ pattern: [crate::SwizzleComponent; 4],
+ },
+}
+
+impl Components {
+ const fn letter_component(letter: char) -> Option<crate::SwizzleComponent> {
+ use crate::SwizzleComponent as Sc;
+ match letter {
+ 'x' | 'r' => Some(Sc::X),
+ 'y' | 'g' => Some(Sc::Y),
+ 'z' | 'b' => Some(Sc::Z),
+ 'w' | 'a' => Some(Sc::W),
+ _ => None,
+ }
+ }
+
+ fn single_component(name: &str, name_span: Span) -> Result<u32, Error> {
+ let ch = name.chars().next().ok_or(Error::BadAccessor(name_span))?;
+ match Self::letter_component(ch) {
+ Some(sc) => Ok(sc as u32),
+ None => Err(Error::BadAccessor(name_span)),
+ }
+ }
+
+ /// Construct a `Components` value from a 'member' name, like `"wzy"` or `"x"`.
+ ///
+ /// Use `name_span` for reporting errors in parsing the component string.
+ fn new(name: &str, name_span: Span) -> Result<Self, Error> {
+ let size = match name.len() {
+ 1 => return Ok(Components::Single(Self::single_component(name, name_span)?)),
+ 2 => crate::VectorSize::Bi,
+ 3 => crate::VectorSize::Tri,
+ 4 => crate::VectorSize::Quad,
+ _ => return Err(Error::BadAccessor(name_span)),
+ };
+
+ let mut pattern = [crate::SwizzleComponent::X; 4];
+ for (comp, ch) in pattern.iter_mut().zip(name.chars()) {
+ *comp = Self::letter_component(ch).ok_or(Error::BadAccessor(name_span))?;
+ }
+
+ Ok(Components::Swizzle { size, pattern })
+ }
+}
+
+/// An `ast::GlobalDecl` for which we have built the Naga IR equivalent.
+enum LoweredGlobalDecl {
+ Function(Handle<crate::Function>),
+ Var(Handle<crate::GlobalVariable>),
+ Const(Handle<crate::Constant>),
+ Type(Handle<crate::Type>),
+ EntryPoint,
+}
+
+enum Texture {
+ Gather,
+ GatherCompare,
+
+ Sample,
+ SampleBias,
+ SampleCompare,
+ SampleCompareLevel,
+ SampleGrad,
+ SampleLevel,
+ // SampleBaseClampToEdge,
+}
+
+impl Texture {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "textureGather" => Self::Gather,
+ "textureGatherCompare" => Self::GatherCompare,
+
+ "textureSample" => Self::Sample,
+ "textureSampleBias" => Self::SampleBias,
+ "textureSampleCompare" => Self::SampleCompare,
+ "textureSampleCompareLevel" => Self::SampleCompareLevel,
+ "textureSampleGrad" => Self::SampleGrad,
+ "textureSampleLevel" => Self::SampleLevel,
+ // "textureSampleBaseClampToEdge" => Some(Self::SampleBaseClampToEdge),
+ _ => return None,
+ })
+ }
+
+ pub const fn min_argument_count(&self) -> u32 {
+ match *self {
+ Self::Gather => 3,
+ Self::GatherCompare => 4,
+
+ Self::Sample => 3,
+ Self::SampleBias => 5,
+ Self::SampleCompare => 5,
+ Self::SampleCompareLevel => 5,
+ Self::SampleGrad => 6,
+ Self::SampleLevel => 5,
+ // Self::SampleBaseClampToEdge => 3,
+ }
+ }
+}
+
+pub struct Lowerer<'source, 'temp> {
+ index: &'temp Index<'source>,
+ layouter: Layouter,
+}
+
+impl<'source, 'temp> Lowerer<'source, 'temp> {
+ pub fn new(index: &'temp Index<'source>) -> Self {
+ Self {
+ index,
+ layouter: Layouter::default(),
+ }
+ }
+
+ pub fn lower(
+ &mut self,
+ tu: &'temp ast::TranslationUnit<'source>,
+ ) -> Result<crate::Module, Error<'source>> {
+ let mut module = crate::Module::default();
+
+ let mut ctx = GlobalContext {
+ ast_expressions: &tu.expressions,
+ globals: &mut FastHashMap::default(),
+ types: &tu.types,
+ module: &mut module,
+ const_typifier: &mut Typifier::new(),
+ };
+
+ for decl_handle in self.index.visit_ordered() {
+ let span = tu.decls.get_span(decl_handle);
+ let decl = &tu.decls[decl_handle];
+
+ match decl.kind {
+ ast::GlobalDeclKind::Fn(ref f) => {
+ let lowered_decl = self.function(f, span, &mut ctx)?;
+ ctx.globals.insert(f.name.name, lowered_decl);
+ }
+ ast::GlobalDeclKind::Var(ref v) => {
+ let ty = self.resolve_ast_type(v.ty, &mut ctx)?;
+
+ let init;
+ if let Some(init_ast) = v.init {
+ let mut ectx = ctx.as_const();
+ let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
+ let ty_res = crate::proc::TypeResolution::Handle(ty);
+ let converted = ectx
+ .try_automatic_conversions(lowered, &ty_res, v.name.span)
+ .map_err(|error| match error {
+ Error::AutoConversion {
+ dest_span: _,
+ dest_type,
+ source_span: _,
+ source_type,
+ } => Error::InitializationTypeMismatch {
+ name: v.name.span,
+ expected: dest_type,
+ got: source_type,
+ },
+ other => other,
+ })?;
+ init = Some(converted);
+ } else {
+ init = None;
+ }
+
+ let binding = if let Some(ref binding) = v.binding {
+ Some(crate::ResourceBinding {
+ group: self.const_u32(binding.group, &mut ctx.as_const())?.0,
+ binding: self.const_u32(binding.binding, &mut ctx.as_const())?.0,
+ })
+ } else {
+ None
+ };
+
+ let handle = ctx.module.global_variables.append(
+ crate::GlobalVariable {
+ name: Some(v.name.name.to_string()),
+ space: v.space,
+ binding,
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(v.name.name, LoweredGlobalDecl::Var(handle));
+ }
+ ast::GlobalDeclKind::Const(ref c) => {
+ let mut ectx = ctx.as_const();
+ let mut init = self.expression_for_abstract(c.init, &mut ectx)?;
+
+ let ty;
+ if let Some(explicit_ty) = c.ty {
+ let explicit_ty =
+ self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?;
+ let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
+ init = ectx
+ .try_automatic_conversions(init, &explicit_ty_res, c.name.span)
+ .map_err(|error| match error {
+ Error::AutoConversion {
+ dest_span: _,
+ dest_type,
+ source_span: _,
+ source_type,
+ } => Error::InitializationTypeMismatch {
+ name: c.name.span,
+ expected: dest_type,
+ got: source_type,
+ },
+ other => other,
+ })?;
+ ty = explicit_ty;
+ } else {
+ init = ectx.concretize(init)?;
+ ty = ectx.register_type(init)?;
+ }
+
+ let handle = ctx.module.constants.append(
+ crate::Constant {
+ name: Some(c.name.name.to_string()),
+ r#override: crate::Override::None,
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(c.name.name, LoweredGlobalDecl::Const(handle));
+ }
+ ast::GlobalDeclKind::Struct(ref s) => {
+ let handle = self.r#struct(s, span, &mut ctx)?;
+ ctx.globals
+ .insert(s.name.name, LoweredGlobalDecl::Type(handle));
+ }
+ ast::GlobalDeclKind::Type(ref alias) => {
+ let ty = self.resolve_named_ast_type(
+ alias.ty,
+ Some(alias.name.name.to_string()),
+ &mut ctx,
+ )?;
+ ctx.globals
+ .insert(alias.name.name, LoweredGlobalDecl::Type(ty));
+ }
+ }
+ }
+
+ // Constant evaluation may leave abstract-typed literals and
+ // compositions in expression arenas, so we need to compact the module
+ // to remove unused expressions and types.
+ crate::compact::compact(&mut module);
+
+ Ok(module)
+ }
+
+ fn function(
+ &mut self,
+ f: &ast::Function<'source>,
+ span: Span,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<LoweredGlobalDecl, Error<'source>> {
+ let mut local_table = FastHashMap::default();
+ let mut expressions = Arena::new();
+ let mut named_expressions = FastIndexMap::default();
+
+ let arguments = f
+ .arguments
+ .iter()
+ .enumerate()
+ .map(|(i, arg)| {
+ let ty = self.resolve_ast_type(arg.ty, ctx)?;
+ let expr = expressions
+ .append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
+ local_table.insert(arg.handle, Typed::Plain(expr));
+ named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
+
+ Ok(crate::FunctionArgument {
+ name: Some(arg.name.name.to_string()),
+ ty,
+ binding: self.binding(&arg.binding, ty, ctx)?,
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let result = f
+ .result
+ .as_ref()
+ .map(|res| {
+ let ty = self.resolve_ast_type(res.ty, ctx)?;
+ Ok(crate::FunctionResult {
+ ty,
+ binding: self.binding(&res.binding, ty, ctx)?,
+ })
+ })
+ .transpose()?;
+
+ let mut function = crate::Function {
+ name: Some(f.name.name.to_string()),
+ arguments,
+ result,
+ local_variables: Arena::new(),
+ expressions,
+ named_expressions: crate::NamedExpressions::default(),
+ body: crate::Block::default(),
+ };
+
+ let mut typifier = Typifier::default();
+ let mut stmt_ctx = StatementContext {
+ local_table: &mut local_table,
+ globals: ctx.globals,
+ ast_expressions: ctx.ast_expressions,
+ const_typifier: ctx.const_typifier,
+ typifier: &mut typifier,
+ function: &mut function,
+ named_expressions: &mut named_expressions,
+ types: ctx.types,
+ module: ctx.module,
+ expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(),
+ };
+ let mut body = self.block(&f.body, false, &mut stmt_ctx)?;
+ ensure_block_returns(&mut body);
+
+ function.body = body;
+ function.named_expressions = named_expressions
+ .into_iter()
+ .map(|(key, (name, _))| (key, name))
+ .collect();
+
+ if let Some(ref entry) = f.entry_point {
+ let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size {
+ // TODO: replace with try_map once stabilized
+ let mut workgroup_size_out = [1; 3];
+ for (i, size) in workgroup_size.into_iter().enumerate() {
+ if let Some(size_expr) = size {
+ workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0;
+ }
+ }
+ workgroup_size_out
+ } else {
+ [0; 3]
+ };
+
+ ctx.module.entry_points.push(crate::EntryPoint {
+ name: f.name.name.to_string(),
+ stage: entry.stage,
+ early_depth_test: entry.early_depth_test,
+ workgroup_size,
+ function,
+ });
+ Ok(LoweredGlobalDecl::EntryPoint)
+ } else {
+ let handle = ctx.module.functions.append(function, span);
+ Ok(LoweredGlobalDecl::Function(handle))
+ }
+ }
+
+ fn block(
+ &mut self,
+ b: &ast::Block<'source>,
+ is_inside_loop: bool,
+ ctx: &mut StatementContext<'source, '_, '_>,
+ ) -> Result<crate::Block, Error<'source>> {
+ let mut block = crate::Block::default();
+
+ for stmt in b.stmts.iter() {
+ self.statement(stmt, &mut block, is_inside_loop, ctx)?;
+ }
+
+ Ok(block)
+ }
+
+ fn statement(
+ &mut self,
+ stmt: &ast::Statement<'source>,
+ block: &mut crate::Block,
+ is_inside_loop: bool,
+ ctx: &mut StatementContext<'source, '_, '_>,
+ ) -> Result<(), Error<'source>> {
+ let out = match stmt.kind {
+ ast::StatementKind::Block(ref block) => {
+ let block = self.block(block, is_inside_loop, ctx)?;
+ crate::Statement::Block(block)
+ }
+ ast::StatementKind::LocalDecl(ref decl) => match *decl {
+ ast::LocalDecl::Let(ref l) => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let value =
+ self.expression(l.init, &mut ctx.as_expression(block, &mut emitter))?;
+
+ // The WGSL spec says that any expression that refers to a
+ // `let`-bound variable is not a const expression. This
+ // affects when errors must be reported, so we can't even
+ // treat suitable `let` bindings as constant as an
+ // optimization.
+ ctx.expression_constness.force_non_const(value);
+
+ let explicit_ty =
+ l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
+ .transpose()?;
+
+ if let Some(ty) = explicit_ty {
+ let mut ctx = ctx.as_expression(block, &mut emitter);
+ let init_ty = ctx.register_type(value)?;
+ if !ctx.module.types[ty]
+ .inner
+ .equivalent(&ctx.module.types[init_ty].inner, &ctx.module.types)
+ {
+ let gctx = &ctx.module.to_ctx();
+ return Err(Error::InitializationTypeMismatch {
+ name: l.name.span,
+ expected: ty.to_wgsl(gctx),
+ got: init_ty.to_wgsl(gctx),
+ });
+ }
+ }
+
+ block.extend(emitter.finish(&ctx.function.expressions));
+ ctx.local_table.insert(l.handle, Typed::Plain(value));
+ ctx.named_expressions
+ .insert(value, (l.name.name.to_string(), l.name.span));
+
+ return Ok(());
+ }
+ ast::LocalDecl::Var(ref v) => {
+ let explicit_ty =
+ v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global()))
+ .transpose()?;
+
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+
+ let ty;
+ let initializer;
+ match (v.init, explicit_ty) {
+ (Some(init), Some(explicit_ty)) => {
+ let init = self.expression_for_abstract(init, &mut ectx)?;
+ let ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
+ let init = ectx
+ .try_automatic_conversions(init, &ty_res, v.name.span)
+ .map_err(|error| match error {
+ Error::AutoConversion {
+ dest_span: _,
+ dest_type,
+ source_span: _,
+ source_type,
+ } => Error::InitializationTypeMismatch {
+ name: v.name.span,
+ expected: dest_type,
+ got: source_type,
+ },
+ other => other,
+ })?;
+ ty = explicit_ty;
+ initializer = Some(init);
+ }
+ (Some(init), None) => {
+ let concretized = self.expression(init, &mut ectx)?;
+ ty = ectx.register_type(concretized)?;
+ initializer = Some(concretized);
+ }
+ (None, Some(explicit_ty)) => {
+ ty = explicit_ty;
+ initializer = None;
+ }
+ (None, None) => return Err(Error::MissingType(v.name.span)),
+ }
+
+ let (const_initializer, initializer) = {
+ match initializer {
+ Some(init) => {
+ // It's not correct to hoist the initializer up
+ // to the top of the function if:
+ // - the initialization is inside a loop, and should
+ // take place on every iteration, or
+ // - the initialization is not a constant
+ // expression, so its value depends on the
+ // state at the point of initialization.
+ if is_inside_loop || !ctx.expression_constness.is_const(init) {
+ (None, Some(init))
+ } else {
+ (Some(init), None)
+ }
+ }
+ None => (None, None),
+ }
+ };
+
+ let var = ctx.function.local_variables.append(
+ crate::LocalVariable {
+ name: Some(v.name.name.to_string()),
+ ty,
+ init: const_initializer,
+ },
+ stmt.span,
+ );
+
+ let handle = ctx.as_expression(block, &mut emitter).interrupt_emitter(
+ crate::Expression::LocalVariable(var),
+ Span::UNDEFINED,
+ )?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+ ctx.local_table.insert(v.handle, Typed::Reference(handle));
+
+ match initializer {
+ Some(initializer) => crate::Statement::Store {
+ pointer: handle,
+ value: initializer,
+ },
+ None => return Ok(()),
+ }
+ }
+ },
+ ast::StatementKind::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let condition =
+ self.expression(condition, &mut ctx.as_expression(block, &mut emitter))?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+
+ let accept = self.block(accept, is_inside_loop, ctx)?;
+ let reject = self.block(reject, is_inside_loop, ctx)?;
+
+ crate::Statement::If {
+ condition,
+ accept,
+ reject,
+ }
+ }
+ ast::StatementKind::Switch {
+ selector,
+ ref cases,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+ let selector = self.expression(selector, &mut ectx)?;
+
+ let uint =
+ resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint);
+ block.extend(emitter.finish(&ctx.function.expressions));
+
+ let cases = cases
+ .iter()
+ .map(|case| {
+ Ok(crate::SwitchCase {
+ value: match case.value {
+ ast::SwitchValue::Expr(expr) => {
+ let span = ctx.ast_expressions.get_span(expr);
+ let expr =
+ self.expression(expr, &mut ctx.as_global().as_const())?;
+ match ctx.module.to_ctx().eval_expr_to_literal(expr) {
+ Some(crate::Literal::I32(value)) if !uint => {
+ crate::SwitchValue::I32(value)
+ }
+ Some(crate::Literal::U32(value)) if uint => {
+ crate::SwitchValue::U32(value)
+ }
+ _ => {
+ return Err(Error::InvalidSwitchValue { uint, span });
+ }
+ }
+ }
+ ast::SwitchValue::Default => crate::SwitchValue::Default,
+ },
+ body: self.block(&case.body, is_inside_loop, ctx)?,
+ fall_through: case.fall_through,
+ })
+ })
+ .collect::<Result<_, _>>()?;
+
+ crate::Statement::Switch { selector, cases }
+ }
+ ast::StatementKind::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let body = self.block(body, true, ctx)?;
+ let mut continuing = self.block(continuing, true, ctx)?;
+
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+ let break_if = break_if
+ .map(|expr| {
+ self.expression(expr, &mut ctx.as_expression(&mut continuing, &mut emitter))
+ })
+ .transpose()?;
+ continuing.extend(emitter.finish(&ctx.function.expressions));
+
+ crate::Statement::Loop {
+ body,
+ continuing,
+ break_if,
+ }
+ }
+ ast::StatementKind::Break => crate::Statement::Break,
+ ast::StatementKind::Continue => crate::Statement::Continue,
+ ast::StatementKind::Return { value } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let value = value
+ .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter)))
+ .transpose()?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+
+ crate::Statement::Return { value }
+ }
+ ast::StatementKind::Kill => crate::Statement::Kill,
+ ast::StatementKind::Call {
+ ref function,
+ ref arguments,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let _ = self.call(
+ stmt.span,
+ function,
+ arguments,
+ &mut ctx.as_expression(block, &mut emitter),
+ )?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+ return Ok(());
+ }
+ ast::StatementKind::Assign {
+ target: ast_target,
+ op,
+ value,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let target = self.expression_for_reference(
+ ast_target,
+ &mut ctx.as_expression(block, &mut emitter),
+ )?;
+ let mut value =
+ self.expression(value, &mut ctx.as_expression(block, &mut emitter))?;
+
+ let target_handle = match target {
+ Typed::Reference(handle) => handle,
+ Typed::Plain(handle) => {
+ let ty = ctx.invalid_assignment_type(handle);
+ return Err(Error::InvalidAssignment {
+ span: ctx.ast_expressions.get_span(ast_target),
+ ty,
+ });
+ }
+ };
+
+ let value = match op {
+ Some(op) => {
+ let mut ctx = ctx.as_expression(block, &mut emitter);
+ let mut left = ctx.apply_load_rule(target)?;
+ ctx.binary_op_splat(op, &mut left, &mut value)?;
+ ctx.append_expression(
+ crate::Expression::Binary {
+ op,
+ left,
+ right: value,
+ },
+ stmt.span,
+ )?
+ }
+ None => value,
+ };
+ block.extend(emitter.finish(&ctx.function.expressions));
+
+ crate::Statement::Store {
+ pointer: target_handle,
+ value,
+ }
+ }
+ ast::StatementKind::Increment(value) | ast::StatementKind::Decrement(value) => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let op = match stmt.kind {
+ ast::StatementKind::Increment(_) => crate::BinaryOperator::Add,
+ ast::StatementKind::Decrement(_) => crate::BinaryOperator::Subtract,
+ _ => unreachable!(),
+ };
+
+ let value_span = ctx.ast_expressions.get_span(value);
+ let target = self
+ .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
+ let target_handle = match target {
+ Typed::Reference(handle) => handle,
+ Typed::Plain(_) => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ };
+
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+ let scalar = match *resolve_inner!(ectx, target_handle) {
+ crate::TypeInner::ValuePointer {
+ size: None, scalar, ..
+ } => scalar,
+ crate::TypeInner::Pointer { base, .. } => match ectx.module.types[base].inner {
+ crate::TypeInner::Scalar(scalar) => scalar,
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ },
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ };
+ let literal = match scalar.kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ crate::Literal::one(scalar)
+ .ok_or(Error::BadIncrDecrReferenceType(value_span))?
+ }
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ };
+
+ let right =
+ ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED)?;
+ let rctx = ectx.runtime_expression_ctx(stmt.span)?;
+ let left = rctx.function.expressions.append(
+ crate::Expression::Load {
+ pointer: target_handle,
+ },
+ value_span,
+ );
+ let value = rctx
+ .function
+ .expressions
+ .append(crate::Expression::Binary { op, left, right }, stmt.span);
+
+ block.extend(emitter.finish(&ctx.function.expressions));
+ crate::Statement::Store {
+ pointer: target_handle,
+ value,
+ }
+ }
+ ast::StatementKind::Ignore(expr) => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let _ = self.expression(expr, &mut ctx.as_expression(block, &mut emitter))?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+ return Ok(());
+ }
+ };
+
+ block.push(out, stmt.span);
+
+ Ok(())
+ }
+
+ /// Lower `expr` and apply the Load Rule if possible.
+ ///
+ /// For the time being, this concretizes abstract values, to support
+ /// consumers that haven't been adapted to consume them yet. Consumers
+ /// prepared for abstract values can call [`expression_for_abstract`].
+ ///
+ /// [`expression_for_abstract`]: Lowerer::expression_for_abstract
+ fn expression(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let expr = self.expression_for_abstract(expr, ctx)?;
+ ctx.concretize(expr)
+ }
+
+ fn expression_for_abstract(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let expr = self.expression_for_reference(expr, ctx)?;
+ ctx.apply_load_rule(expr)
+ }
+
+ fn expression_for_reference(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Typed<Handle<crate::Expression>>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let expr = &ctx.ast_expressions[expr];
+
+ let expr: Typed<crate::Expression> = match *expr {
+ ast::Expression::Literal(literal) => {
+ let literal = match literal {
+ ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f),
+ ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i),
+ ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u),
+ ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f),
+ ast::Literal::Number(Number::AbstractInt(i)) => crate::Literal::AbstractInt(i),
+ ast::Literal::Number(Number::AbstractFloat(f)) => {
+ crate::Literal::AbstractFloat(f)
+ }
+ ast::Literal::Bool(b) => crate::Literal::Bool(b),
+ };
+ let handle = ctx.interrupt_emitter(crate::Expression::Literal(literal), span)?;
+ return Ok(Typed::Plain(handle));
+ }
+ ast::Expression::Ident(ast::IdentExpr::Local(local)) => {
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ return Ok(rctx.local_table[&local]);
+ }
+ ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => {
+ let global = ctx
+ .globals
+ .get(name)
+ .ok_or(Error::UnknownIdent(span, name))?;
+ let expr = match *global {
+ LoweredGlobalDecl::Var(handle) => {
+ let expr = crate::Expression::GlobalVariable(handle);
+ match ctx.module.global_variables[handle].space {
+ crate::AddressSpace::Handle => Typed::Plain(expr),
+ _ => Typed::Reference(expr),
+ }
+ }
+ LoweredGlobalDecl::Const(handle) => {
+ Typed::Plain(crate::Expression::Constant(handle))
+ }
+ _ => {
+ return Err(Error::Unexpected(span, ExpectedToken::Variable));
+ }
+ };
+
+ return expr.try_map(|handle| ctx.interrupt_emitter(handle, span));
+ }
+ ast::Expression::Construct {
+ ref ty,
+ ty_span,
+ ref components,
+ } => {
+ let handle = self.construct(span, ty, ty_span, components, ctx)?;
+ return Ok(Typed::Plain(handle));
+ }
+ ast::Expression::Unary { op, expr } => {
+ let expr = self.expression_for_abstract(expr, ctx)?;
+ Typed::Plain(crate::Expression::Unary { op, expr })
+ }
+ ast::Expression::AddrOf(expr) => {
+ // The `&` operator simply converts a reference to a pointer. And since a
+ // reference is required, the Load Rule is not applied.
+ match self.expression_for_reference(expr, ctx)? {
+ Typed::Reference(handle) => {
+ // No code is generated. We just declare the reference a pointer now.
+ return Ok(Typed::Plain(handle));
+ }
+ Typed::Plain(_) => {
+ return Err(Error::NotReference("the operand of the `&` operator", span));
+ }
+ }
+ }
+ ast::Expression::Deref(expr) => {
+ // The pointer we dereference must be loaded.
+ let pointer = self.expression(expr, ctx)?;
+
+ if resolve_inner!(ctx, pointer).pointer_space().is_none() {
+ return Err(Error::NotPointer(span));
+ }
+
+ // No code is generated. We just declare the pointer a reference now.
+ return Ok(Typed::Reference(pointer));
+ }
+ ast::Expression::Binary { op, left, right } => {
+ self.binary(op, left, right, span, ctx)?
+ }
+ ast::Expression::Call {
+ ref function,
+ ref arguments,
+ } => {
+ let handle = self
+ .call(span, function, arguments, ctx)?
+ .ok_or(Error::FunctionReturnsVoid(function.span))?;
+ return Ok(Typed::Plain(handle));
+ }
+ ast::Expression::Index { base, index } => {
+ let lowered_base = self.expression_for_reference(base, ctx)?;
+ let index = self.expression(index, ctx)?;
+
+ if let Typed::Plain(handle) = lowered_base {
+ if resolve_inner!(ctx, handle).pointer_space().is_some() {
+ return Err(Error::Pointer(
+ "the value indexed by a `[]` subscripting expression",
+ ctx.ast_expressions.get_span(base),
+ ));
+ }
+ }
+
+ lowered_base.map(|base| match ctx.const_access(index) {
+ Some(index) => crate::Expression::AccessIndex { base, index },
+ None => crate::Expression::Access { base, index },
+ })
+ }
+ ast::Expression::Member { base, ref field } => {
+ let lowered_base = self.expression_for_reference(base, ctx)?;
+
+ let temp_inner;
+ let composite_type: &crate::TypeInner = match lowered_base {
+ Typed::Reference(handle) => {
+ let inner = resolve_inner!(ctx, handle);
+ match *inner {
+ crate::TypeInner::Pointer { base, .. } => &ctx.module.types[base].inner,
+ crate::TypeInner::ValuePointer {
+ size: None, scalar, ..
+ } => {
+ temp_inner = crate::TypeInner::Scalar(scalar);
+ &temp_inner
+ }
+ crate::TypeInner::ValuePointer {
+ size: Some(size),
+ scalar,
+ ..
+ } => {
+ temp_inner = crate::TypeInner::Vector { size, scalar };
+ &temp_inner
+ }
+ _ => unreachable!(
+ "In Typed::Reference(handle), handle must be a Naga pointer"
+ ),
+ }
+ }
+
+ Typed::Plain(handle) => {
+ let inner = resolve_inner!(ctx, handle);
+ if let crate::TypeInner::Pointer { .. }
+ | crate::TypeInner::ValuePointer { .. } = *inner
+ {
+ return Err(Error::Pointer(
+ "the value accessed by a `.member` expression",
+ ctx.ast_expressions.get_span(base),
+ ));
+ }
+ inner
+ }
+ };
+
+ let access = match *composite_type {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let index = members
+ .iter()
+ .position(|m| m.name.as_deref() == Some(field.name))
+ .ok_or(Error::BadAccessor(field.span))?
+ as u32;
+
+ lowered_base.map(|base| crate::Expression::AccessIndex { base, index })
+ }
+ crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => {
+ match Components::new(field.name, field.span)? {
+ Components::Swizzle { size, pattern } => {
+ // Swizzles aren't allowed on matrices, but
+ // validation will catch that.
+ Typed::Plain(crate::Expression::Swizzle {
+ size,
+ vector: ctx.apply_load_rule(lowered_base)?,
+ pattern,
+ })
+ }
+ Components::Single(index) => lowered_base
+ .map(|base| crate::Expression::AccessIndex { base, index }),
+ }
+ }
+ _ => return Err(Error::BadAccessor(field.span)),
+ };
+
+ access
+ }
+ ast::Expression::Bitcast { expr, to, ty_span } => {
+ let expr = self.expression(expr, ctx)?;
+ let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?;
+
+ let element_scalar = match ctx.module.types[to_resolved].inner {
+ crate::TypeInner::Scalar(scalar) => scalar,
+ crate::TypeInner::Vector { scalar, .. } => scalar,
+ _ => {
+ let ty = resolve!(ctx, expr);
+ let gctx = &ctx.module.to_ctx();
+ return Err(Error::BadTypeCast {
+ from_type: ty.to_wgsl(gctx),
+ span: ty_span,
+ to_type: to_resolved.to_wgsl(gctx),
+ });
+ }
+ };
+
+ Typed::Plain(crate::Expression::As {
+ expr,
+ kind: element_scalar.kind,
+ convert: None,
+ })
+ }
+ };
+
+ expr.try_map(|handle| ctx.append_expression(handle, span))
+ }
+
+ fn binary(
+ &mut self,
+ op: crate::BinaryOperator,
+ left: Handle<ast::Expression<'source>>,
+ right: Handle<ast::Expression<'source>>,
+ span: Span,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Typed<crate::Expression>, Error<'source>> {
+ // Load both operands.
+ let mut left = self.expression_for_abstract(left, ctx)?;
+ let mut right = self.expression_for_abstract(right, ctx)?;
+
+ // Convert `scalar op vector` to `vector op vector` by introducing
+ // `Splat` expressions.
+ ctx.binary_op_splat(op, &mut left, &mut right)?;
+
+ // Apply automatic conversions.
+ match op {
+ // Shift operators require the right operand to be `u32` or
+ // `vecN<u32>`. We can let the validator sort out vector length
+ // issues, but the right operand must be, or convert to, a u32 leaf
+ // scalar.
+ crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight => {
+ right =
+ ctx.try_automatic_conversion_for_leaf_scalar(right, crate::Scalar::U32, span)?;
+ }
+
+ // All other operators follow the same pattern: reconcile the
+ // scalar leaf types. If there's no reconciliation possible,
+ // leave the expressions as they are: validation will report the
+ // problem.
+ _ => {
+ ctx.grow_types(left)?;
+ ctx.grow_types(right)?;
+ if let Ok(consensus_scalar) =
+ ctx.automatic_conversion_consensus([left, right].iter())
+ {
+ ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
+ ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
+ }
+ }
+ }
+
+ Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
+ }
+
+ /// Generate Naga IR for call expressions and statements, and type
+ /// constructor expressions.
+ ///
+ /// The "function" being called is simply an `Ident` that we know refers to
+ /// some module-scope definition.
+ ///
+ /// - If it is the name of a type, then the expression is a type constructor
+ /// expression: either constructing a value from components, a conversion
+ /// expression, or a zero value expression.
+ ///
+ /// - If it is the name of a function, then we're generating a [`Call`]
+ /// statement. We may be in the midst of generating code for an
+ /// expression, in which case we must generate an `Emit` statement to
+ /// force evaluation of the IR expressions we've generated so far, add the
+ /// `Call` statement to the current block, and then resume generating
+ /// expressions.
+ ///
+ /// [`Call`]: crate::Statement::Call
+ fn call(
+ &mut self,
+ span: Span,
+ function: &ast::Ident<'source>,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Option<Handle<crate::Expression>>, Error<'source>> {
+ match ctx.globals.get(function.name) {
+ Some(&LoweredGlobalDecl::Type(ty)) => {
+ let handle = self.construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ function.span,
+ arguments,
+ ctx,
+ )?;
+ Ok(Some(handle))
+ }
+ Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
+ Err(Error::Unexpected(function.span, ExpectedToken::Function))
+ }
+ Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
+ Some(&LoweredGlobalDecl::Function(function)) => {
+ let arguments = arguments
+ .iter()
+ .map(|&arg| self.expression(arg, ctx))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let has_result = ctx.module.functions[function].result.is_some();
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ // we need to always do this before a fn call since all arguments need to be emitted before the fn call
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ let result = has_result.then(|| {
+ rctx.function
+ .expressions
+ .append(crate::Expression::CallResult(function), span)
+ });
+ rctx.emitter.start(&rctx.function.expressions);
+ rctx.block.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ span,
+ );
+
+ Ok(result)
+ }
+ None => {
+ let span = function.span;
+ let expr = if let Some(fun) = conv::map_relational_fun(function.name) {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let argument = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ // Check for no-op all(bool) and any(bool):
+ let argument_unmodified = matches!(
+ fun,
+ crate::RelationalFunction::All | crate::RelationalFunction::Any
+ ) && {
+ matches!(
+ resolve_inner!(ctx, argument),
+ &crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ })
+ )
+ };
+
+ if argument_unmodified {
+ return Ok(Some(argument));
+ } else {
+ crate::Expression::Relational { fun, argument }
+ }
+ } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::Derivative { axis, ctrl, expr }
+ } else if let Some(fun) = conv::map_standard_fun(function.name) {
+ let expected = fun.argument_count() as _;
+ let mut args = ctx.prepare_args(arguments, expected, span);
+
+ let arg = self.expression(args.next()?, ctx)?;
+ let arg1 = args
+ .next()
+ .map(|x| self.expression(x, ctx))
+ .ok()
+ .transpose()?;
+ let arg2 = args
+ .next()
+ .map(|x| self.expression(x, ctx))
+ .ok()
+ .transpose()?;
+ let arg3 = args
+ .next()
+ .map(|x| self.expression(x, ctx))
+ .ok()
+ .transpose()?;
+
+ args.finish()?;
+
+ if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp {
+ if let Some((size, width)) = match *resolve_inner!(ctx, arg) {
+ crate::TypeInner::Scalar(crate::Scalar { width, .. }) => {
+ Some((None, width))
+ }
+ crate::TypeInner::Vector {
+ size,
+ scalar: crate::Scalar { width, .. },
+ ..
+ } => Some((Some(size), width)),
+ _ => None,
+ } {
+ ctx.module.generate_predeclared_type(
+ if fun == crate::MathFunction::Modf {
+ crate::PredeclaredType::ModfResult { size, width }
+ } else {
+ crate::PredeclaredType::FrexpResult { size, width }
+ },
+ );
+ }
+ }
+
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ }
+ } else if let Some(fun) = Texture::map(function.name) {
+ self.texture_sample_helper(fun, arguments, span, ctx)?
+ } else {
+ match function.name {
+ "select" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let reject = self.expression(args.next()?, ctx)?;
+ let accept = self.expression(args.next()?, ctx)?;
+ let condition = self.expression(args.next()?, ctx)?;
+
+ args.finish()?;
+
+ crate::Expression::Select {
+ reject,
+ accept,
+ condition,
+ }
+ }
+ "arrayLength" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::ArrayLength(expr)
+ }
+ "atomicLoad" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let pointer = self.atomic_pointer(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::Load { pointer }
+ }
+ "atomicStore" => {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+ let pointer = self.atomic_pointer(args.next()?, ctx)?;
+ let value = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ rctx.emitter.start(&rctx.function.expressions);
+ rctx.block
+ .push(crate::Statement::Store { pointer, value }, span);
+ return Ok(None);
+ }
+ "atomicAdd" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Add,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicSub" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Subtract,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicAnd" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::And,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicOr" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::InclusiveOr,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicXor" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::ExclusiveOr,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicMin" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Min,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicMax" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Max,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicExchange" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Exchange { compare: None },
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicCompareExchangeWeak" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let pointer = self.atomic_pointer(args.next()?, ctx)?;
+
+ let compare = self.expression(args.next()?, ctx)?;
+
+ let value = args.next()?;
+ let value_span = ctx.ast_expressions.get_span(value);
+ let value = self.expression(value, ctx)?;
+
+ args.finish()?;
+
+ let expression = match *resolve_inner!(ctx, value) {
+ crate::TypeInner::Scalar(scalar) => {
+ crate::Expression::AtomicResult {
+ ty: ctx.module.generate_predeclared_type(
+ crate::PredeclaredType::AtomicCompareExchangeWeakResult(
+ scalar,
+ ),
+ ),
+ comparison: true,
+ }
+ }
+ _ => return Err(Error::InvalidAtomicOperandType(value_span)),
+ };
+
+ let result = ctx.interrupt_emitter(expression, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::Atomic {
+ pointer,
+ fun: crate::AtomicFunction::Exchange {
+ compare: Some(compare),
+ },
+ value,
+ result,
+ },
+ span,
+ );
+ return Ok(Some(result));
+ }
+ "storageBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::STORAGE), span);
+ return Ok(None);
+ }
+ "workgroupBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
+ return Ok(None);
+ }
+ "workgroupUniformLoad" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = args.next()?;
+ args.finish()?;
+
+ let pointer = self.expression(expr, ctx)?;
+ let result_ty = match *resolve_inner!(ctx, pointer) {
+ crate::TypeInner::Pointer {
+ base,
+ space: crate::AddressSpace::WorkGroup,
+ } => base,
+ ref other => {
+ log::error!("Type {other:?} passed to workgroupUniformLoad");
+ let span = ctx.ast_expressions.get_span(expr);
+ return Err(Error::InvalidWorkGroupUniformLoad(span));
+ }
+ };
+ let result = ctx.interrupt_emitter(
+ crate::Expression::WorkGroupUniformLoadResult { ty: result_ty },
+ span,
+ )?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::WorkGroupUniformLoad { pointer, result },
+ span,
+ );
+
+ return Ok(Some(result));
+ }
+ "textureStore" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx)?;
+
+ let coordinate = self.expression(args.next()?, ctx)?;
+
+ let (_, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| {
+ args.min_args += 1;
+ self.expression(args.next()?, ctx)
+ })
+ .transpose()?;
+
+ let value = self.expression(args.next()?, ctx)?;
+
+ args.finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ rctx.emitter.start(&rctx.function.expressions);
+ let stmt = crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ };
+ rctx.block.push(stmt, span);
+ return Ok(None);
+ }
+ "textureLoad" => {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx)?;
+
+ let coordinate = self.expression(args.next()?, ctx)?;
+
+ let (class, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| {
+ args.min_args += 1;
+ self.expression(args.next()?, ctx)
+ })
+ .transpose()?;
+
+ let level = class
+ .is_mipmapped()
+ .then(|| {
+ args.min_args += 1;
+ self.expression(args.next()?, ctx)
+ })
+ .transpose()?;
+
+ let sample = class
+ .is_multisampled()
+ .then(|| self.expression(args.next()?, ctx))
+ .transpose()?;
+
+ args.finish()?;
+
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ level,
+ sample,
+ }
+ }
+ "textureDimensions" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx)?;
+ let level = args
+ .next()
+ .map(|arg| self.expression(arg, ctx))
+ .ok()
+ .transpose()?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::Size { level },
+ }
+ }
+ "textureNumLevels" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumLevels,
+ }
+ }
+ "textureNumLayers" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumLayers,
+ }
+ }
+ "textureNumSamples" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumSamples,
+ }
+ }
+ "rayQueryInitialize" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+ let query = self.ray_query_pointer(args.next()?, ctx)?;
+ let acceleration_structure = self.expression(args.next()?, ctx)?;
+ let descriptor = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ let _ = ctx.module.generate_ray_desc_type();
+ let fun = crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ };
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ rctx.emitter.start(&rctx.function.expressions);
+ rctx.block
+ .push(crate::Statement::RayQuery { query, fun }, span);
+ return Ok(None);
+ }
+ "rayQueryProceed" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let query = self.ray_query_pointer(args.next()?, ctx)?;
+ args.finish()?;
+
+ let result = ctx.interrupt_emitter(
+ crate::Expression::RayQueryProceedResult,
+ span,
+ )?;
+ let fun = crate::RayQueryFunction::Proceed { result };
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::RayQuery { query, fun }, span);
+ return Ok(Some(result));
+ }
+ "rayQueryGetCommittedIntersection" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let query = self.ray_query_pointer(args.next()?, ctx)?;
+ args.finish()?;
+
+ let _ = ctx.module.generate_ray_intersection_type();
+
+ crate::Expression::RayQueryGetIntersection {
+ query,
+ committed: true,
+ }
+ }
+ "RayDesc" => {
+ let ty = ctx.module.generate_ray_desc_type();
+ let handle = self.construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ function.span,
+ arguments,
+ ctx,
+ )?;
+ return Ok(Some(handle));
+ }
+ _ => return Err(Error::UnknownIdent(function.span, function.name)),
+ }
+ };
+
+ let expr = ctx.append_expression(expr, span)?;
+ Ok(Some(expr))
+ }
+ }
+ }
+
+ fn atomic_pointer(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let pointer = self.expression(expr, ctx)?;
+
+ match *resolve_inner!(ctx, pointer) {
+ crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Ok(pointer),
+ ref other => {
+ log::error!("Pointer type to {:?} passed to atomic op", other);
+ Err(Error::InvalidAtomicPointer(span))
+ }
+ },
+ ref other => {
+ log::error!("Type {:?} passed to atomic op", other);
+ Err(Error::InvalidAtomicPointer(span))
+ }
+ }
+ }
+
+ fn atomic_helper(
+ &mut self,
+ span: Span,
+ fun: crate::AtomicFunction,
+ args: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(args, 2, span);
+
+ let pointer = self.atomic_pointer(args.next()?, ctx)?;
+
+ let value = args.next()?;
+ let value = self.expression(value, ctx)?;
+ let ty = ctx.register_type(value)?;
+
+ args.finish()?;
+
+ let result = ctx.interrupt_emitter(
+ crate::Expression::AtomicResult {
+ ty,
+ comparison: false,
+ },
+ span,
+ )?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::Atomic {
+ pointer,
+ fun,
+ value,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
+ fn texture_sample_helper(
+ &mut self,
+ fun: Texture,
+ args: &[Handle<ast::Expression<'source>>],
+ span: Span,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<crate::Expression, Error<'source>> {
+ let mut args = ctx.prepare_args(args, fun.min_argument_count(), span);
+
+ fn get_image_and_span<'source>(
+ lowerer: &mut Lowerer<'source, '_>,
+ args: &mut ArgumentContext<'_, 'source>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<(Handle<crate::Expression>, Span), Error<'source>> {
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = lowerer.expression(image, ctx)?;
+ Ok((image, image_span))
+ }
+
+ let (image, image_span, gather) = match fun {
+ Texture::Gather => {
+ let image_or_component = args.next()?;
+ let image_or_component_span = ctx.ast_expressions.get_span(image_or_component);
+ // Gathers from depth textures don't take an initial `component` argument.
+ let lowered_image_or_component = self.expression(image_or_component, ctx)?;
+
+ match *resolve_inner!(ctx, lowered_image_or_component) {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Depth { .. },
+ ..
+ } => (
+ lowered_image_or_component,
+ image_or_component_span,
+ Some(crate::SwizzleComponent::X),
+ ),
+ _ => {
+ let (image, image_span) = get_image_and_span(self, &mut args, ctx)?;
+ (
+ image,
+ image_span,
+ Some(ctx.gather_component(
+ lowered_image_or_component,
+ image_or_component_span,
+ span,
+ )?),
+ )
+ }
+ }
+ }
+ Texture::GatherCompare => {
+ let (image, image_span) = get_image_and_span(self, &mut args, ctx)?;
+ (image, image_span, Some(crate::SwizzleComponent::X))
+ }
+
+ _ => {
+ let (image, image_span) = get_image_and_span(self, &mut args, ctx)?;
+ (image, image_span, None)
+ }
+ };
+
+ let sampler = self.expression(args.next()?, ctx)?;
+
+ let coordinate = self.expression(args.next()?, ctx)?;
+
+ let (_, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| self.expression(args.next()?, ctx))
+ .transpose()?;
+
+ let (level, depth_ref) = match fun {
+ Texture::Gather => (crate::SampleLevel::Zero, None),
+ Texture::GatherCompare => {
+ let reference = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Zero, Some(reference))
+ }
+
+ Texture::Sample => (crate::SampleLevel::Auto, None),
+ Texture::SampleBias => {
+ let bias = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Bias(bias), None)
+ }
+ Texture::SampleCompare => {
+ let reference = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Auto, Some(reference))
+ }
+ Texture::SampleCompareLevel => {
+ let reference = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Zero, Some(reference))
+ }
+ Texture::SampleGrad => {
+ let x = self.expression(args.next()?, ctx)?;
+ let y = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Gradient { x, y }, None)
+ }
+ Texture::SampleLevel => {
+ let level = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Exact(level), None)
+ }
+ };
+
+ let offset = args
+ .next()
+ .map(|arg| self.expression(arg, &mut ctx.as_const()))
+ .ok()
+ .transpose()?;
+
+ args.finish()?;
+
+ Ok(crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ })
+ }
+
+ fn r#struct(
+ &mut self,
+ s: &ast::Struct<'source>,
+ span: Span,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ let mut offset = 0;
+ let mut struct_alignment = Alignment::ONE;
+ let mut members = Vec::with_capacity(s.members.len());
+
+ for member in s.members.iter() {
+ let ty = self.resolve_ast_type(member.ty, ctx)?;
+
+ self.layouter.update(ctx.module.to_ctx()).unwrap();
+
+ let member_min_size = self.layouter[ty].size;
+ let member_min_alignment = self.layouter[ty].alignment;
+
+ let member_size = if let Some(size_expr) = member.size {
+ let (size, span) = self.const_u32(size_expr, &mut ctx.as_const())?;
+ if size < member_min_size {
+ return Err(Error::SizeAttributeTooLow(span, member_min_size));
+ } else {
+ size
+ }
+ } else {
+ member_min_size
+ };
+
+ let member_alignment = if let Some(align_expr) = member.align {
+ let (align, span) = self.const_u32(align_expr, &mut ctx.as_const())?;
+ if let Some(alignment) = Alignment::new(align) {
+ if alignment < member_min_alignment {
+ return Err(Error::AlignAttributeTooLow(span, member_min_alignment));
+ } else {
+ alignment
+ }
+ } else {
+ return Err(Error::NonPowerOfTwoAlignAttribute(span));
+ }
+ } else {
+ member_min_alignment
+ };
+
+ let binding = self.binding(&member.binding, ty, ctx)?;
+
+ offset = member_alignment.round_up(offset);
+ struct_alignment = struct_alignment.max(member_alignment);
+
+ members.push(crate::StructMember {
+ name: Some(member.name.name.to_owned()),
+ ty,
+ binding,
+ offset,
+ });
+
+ offset += member_size;
+ }
+
+ let size = struct_alignment.round_up(offset);
+ let inner = crate::TypeInner::Struct {
+ members,
+ span: size,
+ };
+
+ let handle = ctx.module.types.insert(
+ crate::Type {
+ name: Some(s.name.name.to_string()),
+ inner,
+ },
+ span,
+ );
+ Ok(handle)
+ }
+
+ fn const_u32(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<(u32, Span), Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let expr = self.expression(expr, ctx)?;
+ let value = ctx
+ .module
+ .to_ctx()
+ .eval_expr_to_u32(expr)
+ .map_err(|err| match err {
+ crate::proc::U32EvalError::NonConst => {
+ Error::ExpectedConstExprConcreteIntegerScalar(span)
+ }
+ crate::proc::U32EvalError::Negative => Error::ExpectedNonNegative(span),
+ })?;
+ Ok((value, span))
+ }
+
+ fn array_size(
+ &mut self,
+ size: ast::ArraySize<'source>,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<crate::ArraySize, Error<'source>> {
+ Ok(match size {
+ ast::ArraySize::Constant(expr) => {
+ let span = ctx.ast_expressions.get_span(expr);
+ let const_expr = self.expression(expr, &mut ctx.as_const())?;
+ let len =
+ ctx.module
+ .to_ctx()
+ .eval_expr_to_u32(const_expr)
+ .map_err(|err| match err {
+ crate::proc::U32EvalError::NonConst => {
+ Error::ExpectedConstExprConcreteIntegerScalar(span)
+ }
+ crate::proc::U32EvalError::Negative => {
+ Error::ExpectedPositiveArrayLength(span)
+ }
+ })?;
+ let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?;
+ crate::ArraySize::Constant(size)
+ }
+ ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
+ })
+ }
+
+ /// Build the Naga equivalent of a named AST type.
+ ///
+ /// Return a Naga `Handle<Type>` representing the front-end type
+ /// `handle`, which should be named `name`, if given.
+ ///
+ /// If `handle` refers to a type cached in [`SpecialTypes`],
+ /// `name` may be ignored.
+ ///
+ /// [`SpecialTypes`]: crate::SpecialTypes
+ fn resolve_named_ast_type(
+ &mut self,
+ handle: Handle<ast::Type<'source>>,
+ name: Option<String>,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ let inner = match ctx.types[handle] {
+ ast::Type::Scalar(scalar) => scalar.to_inner_scalar(),
+ ast::Type::Vector { size, scalar } => scalar.to_inner_vector(size),
+ ast::Type::Matrix {
+ rows,
+ columns,
+ width,
+ } => crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: crate::Scalar::float(width),
+ },
+ ast::Type::Atomic(scalar) => scalar.to_inner_atomic(),
+ ast::Type::Pointer { base, space } => {
+ let base = self.resolve_ast_type(base, ctx)?;
+ crate::TypeInner::Pointer { base, space }
+ }
+ ast::Type::Array { base, size } => {
+ let base = self.resolve_ast_type(base, ctx)?;
+ let size = self.array_size(size, ctx)?;
+
+ self.layouter.update(ctx.module.to_ctx()).unwrap();
+ let stride = self.layouter[base].to_stride();
+
+ crate::TypeInner::Array { base, size, stride }
+ }
+ ast::Type::Image {
+ dim,
+ arrayed,
+ class,
+ } => crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison },
+ ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure,
+ ast::Type::RayQuery => crate::TypeInner::RayQuery,
+ ast::Type::BindingArray { base, size } => {
+ let base = self.resolve_ast_type(base, ctx)?;
+ let size = self.array_size(size, ctx)?;
+ crate::TypeInner::BindingArray { base, size }
+ }
+ ast::Type::RayDesc => {
+ return Ok(ctx.module.generate_ray_desc_type());
+ }
+ ast::Type::RayIntersection => {
+ return Ok(ctx.module.generate_ray_intersection_type());
+ }
+ ast::Type::User(ref ident) => {
+ return match ctx.globals.get(ident.name) {
+ Some(&LoweredGlobalDecl::Type(handle)) => Ok(handle),
+ Some(_) => Err(Error::Unexpected(ident.span, ExpectedToken::Type)),
+ None => Err(Error::UnknownType(ident.span)),
+ }
+ }
+ };
+
+ Ok(ctx.ensure_type_exists(name, inner))
+ }
+
+ /// Return a Naga `Handle<Type>` representing the front-end type `handle`.
+ fn resolve_ast_type(
+ &mut self,
+ handle: Handle<ast::Type<'source>>,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ self.resolve_named_ast_type(handle, None, ctx)
+ }
+
+ fn binding(
+ &mut self,
+ binding: &Option<ast::Binding<'source>>,
+ ty: Handle<crate::Type>,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<Option<crate::Binding>, Error<'source>> {
+ Ok(match *binding {
+ Some(ast::Binding::BuiltIn(b)) => Some(crate::Binding::BuiltIn(b)),
+ Some(ast::Binding::Location {
+ location,
+ second_blend_source,
+ interpolation,
+ sampling,
+ }) => {
+ let mut binding = crate::Binding::Location {
+ location: self.const_u32(location, &mut ctx.as_const())?.0,
+ second_blend_source,
+ interpolation,
+ sampling,
+ };
+ binding.apply_default_interpolation(&ctx.module.types[ty].inner);
+ Some(binding)
+ }
+ None => None,
+ })
+ }
+
+ fn ray_query_pointer(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let pointer = self.expression(expr, ctx)?;
+
+ match *resolve_inner!(ctx, pointer) {
+ crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
+ crate::TypeInner::RayQuery => Ok(pointer),
+ ref other => {
+ log::error!("Pointer type to {:?} passed to ray query op", other);
+ Err(Error::InvalidRayQueryPointer(span))
+ }
+ },
+ ref other => {
+ log::error!("Type {:?} passed to ray query op", other);
+ Err(Error::InvalidRayQueryPointer(span))
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs
new file mode 100644
index 0000000000..b6151fe1c0
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/mod.rs
@@ -0,0 +1,49 @@
+/*!
+Frontend for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+mod error;
+mod index;
+mod lower;
+mod parse;
+#[cfg(test)]
+mod tests;
+mod to_wgsl;
+
+use crate::front::wgsl::error::Error;
+use crate::front::wgsl::parse::Parser;
+use thiserror::Error;
+
+pub use crate::front::wgsl::error::ParseError;
+use crate::front::wgsl::lower::Lowerer;
+use crate::Scalar;
+
+pub struct Frontend {
+ parser: Parser,
+}
+
+impl Frontend {
+ pub const fn new() -> Self {
+ Self {
+ parser: Parser::new(),
+ }
+ }
+
+ pub fn parse(&mut self, source: &str) -> Result<crate::Module, ParseError> {
+ self.inner(source).map_err(|x| x.as_parse_error(source))
+ }
+
+ fn inner<'a>(&mut self, source: &'a str) -> Result<crate::Module, Error<'a>> {
+ let tu = self.parser.parse(source)?;
+ let index = index::Index::generate(&tu)?;
+ let module = Lowerer::new(&index).lower(&tu)?;
+
+ Ok(module)
+ }
+}
+
+pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
+ Frontend::new().parse(source)
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
new file mode 100644
index 0000000000..dbaac523cb
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
@@ -0,0 +1,491 @@
+use crate::front::wgsl::parse::number::Number;
+use crate::front::wgsl::Scalar;
+use crate::{Arena, FastIndexSet, Handle, Span};
+use std::hash::Hash;
+
+#[derive(Debug, Default)]
+pub struct TranslationUnit<'a> {
+ pub decls: Arena<GlobalDecl<'a>>,
+ /// The common expressions arena for the entire translation unit.
+ ///
+ /// All functions, global initializers, array lengths, etc. store their
+ /// expressions here. We apportion these out to individual Naga
+ /// [`Function`]s' expression arenas at lowering time. Keeping them all in a
+ /// single arena simplifies handling of things like array lengths (which are
+ /// effectively global and thus don't clearly belong to any function) and
+ /// initializers (which can appear in both function-local and module-scope
+ /// contexts).
+ ///
+ /// [`Function`]: crate::Function
+ pub expressions: Arena<Expression<'a>>,
+
+ /// Non-user-defined types, like `vec4<f32>` or `array<i32, 10>`.
+ ///
+ /// These are referred to by `Handle<ast::Type<'a>>` values.
+ /// User-defined types are referred to by name until lowering.
+ pub types: Arena<Type<'a>>,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct Ident<'a> {
+ pub name: &'a str,
+ pub span: Span,
+}
+
+#[derive(Debug)]
+pub enum IdentExpr<'a> {
+ Unresolved(&'a str),
+ Local(Handle<Local>),
+}
+
+/// A reference to a module-scope definition or predeclared object.
+///
+/// Each [`GlobalDecl`] holds a set of these values, to be resolved to
+/// specific definitions later. To support de-duplication, `Eq` and
+/// `Hash` on a `Dependency` value consider only the name, not the
+/// source location at which the reference occurs.
+#[derive(Debug)]
+pub struct Dependency<'a> {
+ /// The name referred to.
+ pub ident: &'a str,
+
+ /// The location at which the reference to that name occurs.
+ pub usage: Span,
+}
+
+impl Hash for Dependency<'_> {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.ident.hash(state);
+ }
+}
+
+impl PartialEq for Dependency<'_> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ident == other.ident
+ }
+}
+
+impl Eq for Dependency<'_> {}
+
+/// A module-scope declaration.
+#[derive(Debug)]
+pub struct GlobalDecl<'a> {
+ pub kind: GlobalDeclKind<'a>,
+
+ /// Names of all module-scope or predeclared objects this
+ /// declaration uses.
+ pub dependencies: FastIndexSet<Dependency<'a>>,
+}
+
+#[derive(Debug)]
+pub enum GlobalDeclKind<'a> {
+ Fn(Function<'a>),
+ Var(GlobalVariable<'a>),
+ Const(Const<'a>),
+ Struct(Struct<'a>),
+ Type(TypeAlias<'a>),
+}
+
+#[derive(Debug)]
+pub struct FunctionArgument<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<Binding<'a>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub struct FunctionResult<'a> {
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<Binding<'a>>,
+}
+
+#[derive(Debug)]
+pub struct EntryPoint<'a> {
+ pub stage: crate::ShaderStage,
+ pub early_depth_test: Option<crate::EarlyDepthTest>,
+ pub workgroup_size: Option<[Option<Handle<Expression<'a>>>; 3]>,
+}
+
+#[cfg(doc)]
+use crate::front::wgsl::lower::{RuntimeExpressionContext, StatementContext};
+
+#[derive(Debug)]
+pub struct Function<'a> {
+ pub entry_point: Option<EntryPoint<'a>>,
+ pub name: Ident<'a>,
+ pub arguments: Vec<FunctionArgument<'a>>,
+ pub result: Option<FunctionResult<'a>>,
+
+ /// Local variable and function argument arena.
+ ///
+ /// Note that the `Local` here is actually a zero-sized type. The AST keeps
+ /// all the detailed information about locals - names, types, etc. - in
+ /// [`LocalDecl`] statements. For arguments, that information is kept in
+ /// [`arguments`]. This `Arena`'s only role is to assign a unique `Handle`
+ /// to each of them, and track their definitions' spans for use in
+ /// diagnostics.
+ ///
+ /// In the AST, when an [`Ident`] expression refers to a local variable or
+ /// argument, its [`IdentExpr`] holds the referent's `Handle<Local>` in this
+ /// arena.
+ ///
+ /// During lowering, [`LocalDecl`] statements add entries to a per-function
+ /// table that maps `Handle<Local>` values to their Naga representations,
+ /// accessed via [`StatementContext::local_table`] and
+ /// [`RuntimeExpressionContext::local_table`]. This table is then consulted when
+ /// lowering subsequent [`Ident`] expressions.
+ ///
+ /// [`LocalDecl`]: StatementKind::LocalDecl
+ /// [`arguments`]: Function::arguments
+ /// [`Ident`]: Expression::Ident
+ /// [`StatementContext::local_table`]: StatementContext::local_table
+ /// [`RuntimeExpressionContext::local_table`]: RuntimeExpressionContext::local_table
+ pub locals: Arena<Local>,
+
+ pub body: Block<'a>,
+}
+
+#[derive(Debug)]
+pub enum Binding<'a> {
+ BuiltIn(crate::BuiltIn),
+ Location {
+ location: Handle<Expression<'a>>,
+ second_blend_source: bool,
+ interpolation: Option<crate::Interpolation>,
+ sampling: Option<crate::Sampling>,
+ },
+}
+
+#[derive(Debug)]
+pub struct ResourceBinding<'a> {
+ pub group: Handle<Expression<'a>>,
+ pub binding: Handle<Expression<'a>>,
+}
+
+#[derive(Debug)]
+pub struct GlobalVariable<'a> {
+ pub name: Ident<'a>,
+ pub space: crate::AddressSpace,
+ pub binding: Option<ResourceBinding<'a>>,
+ pub ty: Handle<Type<'a>>,
+ pub init: Option<Handle<Expression<'a>>>,
+}
+
+#[derive(Debug)]
+pub struct StructMember<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<Binding<'a>>,
+ pub align: Option<Handle<Expression<'a>>>,
+ pub size: Option<Handle<Expression<'a>>>,
+}
+
+#[derive(Debug)]
+pub struct Struct<'a> {
+ pub name: Ident<'a>,
+ pub members: Vec<StructMember<'a>>,
+}
+
+#[derive(Debug)]
+pub struct TypeAlias<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+}
+
+#[derive(Debug)]
+pub struct Const<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Handle<Expression<'a>>,
+}
+
+/// The size of an [`Array`] or [`BindingArray`].
+///
+/// [`Array`]: Type::Array
+/// [`BindingArray`]: Type::BindingArray
+#[derive(Debug, Copy, Clone)]
+pub enum ArraySize<'a> {
+ /// The length as a constant expression.
+ Constant(Handle<Expression<'a>>),
+ Dynamic,
+}
+
+#[derive(Debug)]
+pub enum Type<'a> {
+ Scalar(Scalar),
+ Vector {
+ size: crate::VectorSize,
+ scalar: Scalar,
+ },
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+ Atomic(Scalar),
+ Pointer {
+ base: Handle<Type<'a>>,
+ space: crate::AddressSpace,
+ },
+ Array {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+ Image {
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ },
+ Sampler {
+ comparison: bool,
+ },
+ AccelerationStructure,
+ RayQuery,
+ RayDesc,
+ RayIntersection,
+ BindingArray {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+
+ /// A user-defined type, like a struct or a type alias.
+ User(Ident<'a>),
+}
+
+#[derive(Debug, Default)]
+pub struct Block<'a> {
+ pub stmts: Vec<Statement<'a>>,
+}
+
+#[derive(Debug)]
+pub struct Statement<'a> {
+ pub kind: StatementKind<'a>,
+ pub span: Span,
+}
+
+#[derive(Debug)]
+pub enum StatementKind<'a> {
+ LocalDecl(LocalDecl<'a>),
+ Block(Block<'a>),
+ If {
+ condition: Handle<Expression<'a>>,
+ accept: Block<'a>,
+ reject: Block<'a>,
+ },
+ Switch {
+ selector: Handle<Expression<'a>>,
+ cases: Vec<SwitchCase<'a>>,
+ },
+ Loop {
+ body: Block<'a>,
+ continuing: Block<'a>,
+ break_if: Option<Handle<Expression<'a>>>,
+ },
+ Break,
+ Continue,
+ Return {
+ value: Option<Handle<Expression<'a>>>,
+ },
+ Kill,
+ Call {
+ function: Ident<'a>,
+ arguments: Vec<Handle<Expression<'a>>>,
+ },
+ Assign {
+ target: Handle<Expression<'a>>,
+ op: Option<crate::BinaryOperator>,
+ value: Handle<Expression<'a>>,
+ },
+ Increment(Handle<Expression<'a>>),
+ Decrement(Handle<Expression<'a>>),
+ Ignore(Handle<Expression<'a>>),
+}
+
+#[derive(Debug)]
+pub enum SwitchValue<'a> {
+ Expr(Handle<Expression<'a>>),
+ Default,
+}
+
+#[derive(Debug)]
+pub struct SwitchCase<'a> {
+ pub value: SwitchValue<'a>,
+ pub body: Block<'a>,
+ pub fall_through: bool,
+}
+
+/// A type at the head of a [`Construct`] expression.
+///
+/// WGSL has two types of [`type constructor expressions`]:
+///
+/// - Those that fully specify the type being constructed, like
+/// `vec3<f32>(x,y,z)`, which obviously constructs a `vec3<f32>`.
+///
+/// - Those that leave the component type of the composite being constructed
+/// implicit, to be inferred from the argument types, like `vec3(x,y,z)`,
+/// which constructs a `vec3<T>` where `T` is the type of `x`, `y`, and `z`.
+///
+/// This enum represents the head type of both cases. The `PartialFoo` variants
+/// represent the second case, where the component type is implicit.
+///
+/// This does not cover structs or types referred to by type aliases. See the
+/// documentation for [`Construct`] and [`Call`] expressions for details.
+///
+/// [`Construct`]: Expression::Construct
+/// [`type constructor expressions`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr
+/// [`Call`]: Expression::Call
+#[derive(Debug)]
+pub enum ConstructorType<'a> {
+ /// A scalar type or conversion: `f32(1)`.
+ Scalar(Scalar),
+
+ /// A vector construction whose component type is inferred from the
+ /// argument: `vec3(1.0)`.
+ PartialVector { size: crate::VectorSize },
+
+ /// A vector construction whose component type is written out:
+ /// `vec3<f32>(1.0)`.
+ Vector {
+ size: crate::VectorSize,
+ scalar: Scalar,
+ },
+
+ /// A matrix construction whose component type is inferred from the
+ /// argument: `mat2x2(1,2,3,4)`.
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+
+ /// A matrix construction whose component type is written out:
+ /// `mat2x2<f32>(1,2,3,4)`.
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+
+ /// An array whose component type and size are inferred from the arguments:
+ /// `array(3,4,5)`.
+ PartialArray,
+
+ /// An array whose component type and size are written out:
+ /// `array<u32, 4>(3,4,5)`.
+ Array {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+
+ /// Constructing a value of a known Naga IR type.
+ ///
+ /// This variant is produced only during lowering, when we have Naga types
+ /// available, never during parsing.
+ Type(Handle<crate::Type>),
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum Literal {
+ Bool(bool),
+ Number(Number),
+}
+
+#[cfg(doc)]
+use crate::front::wgsl::lower::Lowerer;
+
+#[derive(Debug)]
+pub enum Expression<'a> {
+ Literal(Literal),
+ Ident(IdentExpr<'a>),
+
+ /// A type constructor expression.
+ ///
+ /// This is only used for expressions like `KEYWORD(EXPR...)` and
+ /// `KEYWORD<PARAM>(EXPR...)`, where `KEYWORD` is a [type-defining keyword] like
+ /// `vec3`. These keywords cannot be shadowed by user definitions, so we can
+ /// tell that such an expression is a construction immediately.
+ ///
+ /// For ordinary identifiers, we can't tell whether an expression like
+ /// `IDENTIFIER(EXPR, ...)` is a construction expression or a function call
+ /// until we know `IDENTIFIER`'s definition, so we represent those as
+ /// [`Call`] expressions.
+ ///
+ /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords
+ /// [`Call`]: Expression::Call
+ Construct {
+ ty: ConstructorType<'a>,
+ ty_span: Span,
+ components: Vec<Handle<Expression<'a>>>,
+ },
+ Unary {
+ op: crate::UnaryOperator,
+ expr: Handle<Expression<'a>>,
+ },
+ AddrOf(Handle<Expression<'a>>),
+ Deref(Handle<Expression<'a>>),
+ Binary {
+ op: crate::BinaryOperator,
+ left: Handle<Expression<'a>>,
+ right: Handle<Expression<'a>>,
+ },
+
+ /// A function call or type constructor expression.
+ ///
+ /// We can't tell whether an expression like `IDENTIFIER(EXPR, ...)` is a
+ /// construction expression or a function call until we know `IDENTIFIER`'s
+ /// definition, so we represent everything of that form as one of these
+ /// expressions until lowering. At that point, [`Lowerer::call`] has
+ /// everything's definition in hand, and can decide whether to emit a Naga
+ /// [`Constant`], [`As`], [`Splat`], or [`Compose`] expression.
+ ///
+ /// [`Lowerer::call`]: Lowerer::call
+ /// [`Constant`]: crate::Expression::Constant
+ /// [`As`]: crate::Expression::As
+ /// [`Splat`]: crate::Expression::Splat
+ /// [`Compose`]: crate::Expression::Compose
+ Call {
+ function: Ident<'a>,
+ arguments: Vec<Handle<Expression<'a>>>,
+ },
+ Index {
+ base: Handle<Expression<'a>>,
+ index: Handle<Expression<'a>>,
+ },
+ Member {
+ base: Handle<Expression<'a>>,
+ field: Ident<'a>,
+ },
+ Bitcast {
+ expr: Handle<Expression<'a>>,
+ to: Handle<Type<'a>>,
+ ty_span: Span,
+ },
+}
+
+#[derive(Debug)]
+pub struct LocalVariable<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Option<Handle<Expression<'a>>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub struct Let<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Handle<Expression<'a>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub enum LocalDecl<'a> {
+ Var(LocalVariable<'a>),
+ Let(Let<'a>),
+}
+
+#[derive(Debug)]
+/// A placeholder for a local variable declaration.
+///
+/// See [`Function::locals`] for more information.
+pub struct Local;
diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
new file mode 100644
index 0000000000..08f1e39285
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
@@ -0,0 +1,254 @@
+use super::Error;
+use crate::front::wgsl::Scalar;
+use crate::Span;
+
+pub fn map_address_space(word: &str, span: Span) -> Result<crate::AddressSpace, Error<'_>> {
+ match word {
+ "private" => Ok(crate::AddressSpace::Private),
+ "workgroup" => Ok(crate::AddressSpace::WorkGroup),
+ "uniform" => Ok(crate::AddressSpace::Uniform),
+ "storage" => Ok(crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ }),
+ "push_constant" => Ok(crate::AddressSpace::PushConstant),
+ "function" => Ok(crate::AddressSpace::Function),
+ _ => Err(Error::UnknownAddressSpace(span)),
+ }
+}
+
+pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> {
+ Ok(match word {
+ "position" => crate::BuiltIn::Position { invariant: false },
+ // vertex
+ "vertex_index" => crate::BuiltIn::VertexIndex,
+ "instance_index" => crate::BuiltIn::InstanceIndex,
+ "view_index" => crate::BuiltIn::ViewIndex,
+ // fragment
+ "front_facing" => crate::BuiltIn::FrontFacing,
+ "frag_depth" => crate::BuiltIn::FragDepth,
+ "primitive_index" => crate::BuiltIn::PrimitiveIndex,
+ "sample_index" => crate::BuiltIn::SampleIndex,
+ "sample_mask" => crate::BuiltIn::SampleMask,
+ // compute
+ "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
+ "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
+ "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
+ "workgroup_id" => crate::BuiltIn::WorkGroupId,
+ "num_workgroups" => crate::BuiltIn::NumWorkGroups,
+ _ => return Err(Error::UnknownBuiltin(span)),
+ })
+}
+
+pub fn map_interpolation(word: &str, span: Span) -> Result<crate::Interpolation, Error<'_>> {
+ match word {
+ "linear" => Ok(crate::Interpolation::Linear),
+ "flat" => Ok(crate::Interpolation::Flat),
+ "perspective" => Ok(crate::Interpolation::Perspective),
+ _ => Err(Error::UnknownAttribute(span)),
+ }
+}
+
+pub fn map_sampling(word: &str, span: Span) -> Result<crate::Sampling, Error<'_>> {
+ match word {
+ "center" => Ok(crate::Sampling::Center),
+ "centroid" => Ok(crate::Sampling::Centroid),
+ "sample" => Ok(crate::Sampling::Sample),
+ _ => Err(Error::UnknownAttribute(span)),
+ }
+}
+
+pub fn map_storage_format(word: &str, span: Span) -> Result<crate::StorageFormat, Error<'_>> {
+ use crate::StorageFormat as Sf;
+ Ok(match word {
+ "r8unorm" => Sf::R8Unorm,
+ "r8snorm" => Sf::R8Snorm,
+ "r8uint" => Sf::R8Uint,
+ "r8sint" => Sf::R8Sint,
+ "r16unorm" => Sf::R16Unorm,
+ "r16snorm" => Sf::R16Snorm,
+ "r16uint" => Sf::R16Uint,
+ "r16sint" => Sf::R16Sint,
+ "r16float" => Sf::R16Float,
+ "rg8unorm" => Sf::Rg8Unorm,
+ "rg8snorm" => Sf::Rg8Snorm,
+ "rg8uint" => Sf::Rg8Uint,
+ "rg8sint" => Sf::Rg8Sint,
+ "r32uint" => Sf::R32Uint,
+ "r32sint" => Sf::R32Sint,
+ "r32float" => Sf::R32Float,
+ "rg16unorm" => Sf::Rg16Unorm,
+ "rg16snorm" => Sf::Rg16Snorm,
+ "rg16uint" => Sf::Rg16Uint,
+ "rg16sint" => Sf::Rg16Sint,
+ "rg16float" => Sf::Rg16Float,
+ "rgba8unorm" => Sf::Rgba8Unorm,
+ "rgba8snorm" => Sf::Rgba8Snorm,
+ "rgba8uint" => Sf::Rgba8Uint,
+ "rgba8sint" => Sf::Rgba8Sint,
+ "rgb10a2uint" => Sf::Rgb10a2Uint,
+ "rgb10a2unorm" => Sf::Rgb10a2Unorm,
+ "rg11b10float" => Sf::Rg11b10Float,
+ "rg32uint" => Sf::Rg32Uint,
+ "rg32sint" => Sf::Rg32Sint,
+ "rg32float" => Sf::Rg32Float,
+ "rgba16unorm" => Sf::Rgba16Unorm,
+ "rgba16snorm" => Sf::Rgba16Snorm,
+ "rgba16uint" => Sf::Rgba16Uint,
+ "rgba16sint" => Sf::Rgba16Sint,
+ "rgba16float" => Sf::Rgba16Float,
+ "rgba32uint" => Sf::Rgba32Uint,
+ "rgba32sint" => Sf::Rgba32Sint,
+ "rgba32float" => Sf::Rgba32Float,
+ "bgra8unorm" => Sf::Bgra8Unorm,
+ _ => return Err(Error::UnknownStorageFormat(span)),
+ })
+}
+
+pub fn get_scalar_type(word: &str) -> Option<Scalar> {
+ use crate::ScalarKind as Sk;
+ match word {
+ // "f16" => Some(Scalar { kind: Sk::Float, width: 2 }),
+ "f32" => Some(Scalar {
+ kind: Sk::Float,
+ width: 4,
+ }),
+ "f64" => Some(Scalar {
+ kind: Sk::Float,
+ width: 8,
+ }),
+ "i32" => Some(Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ }),
+ "u32" => Some(Scalar {
+ kind: Sk::Uint,
+ width: 4,
+ }),
+ "bool" => Some(Scalar {
+ kind: Sk::Bool,
+ width: crate::BOOL_WIDTH,
+ }),
+ _ => None,
+ }
+}
+
+pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ match word {
+ "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
+ "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
+ "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
+ "dpdxFine" => Some((Axis::X, Ctrl::Fine)),
+ "dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
+ "fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
+ "dpdx" => Some((Axis::X, Ctrl::None)),
+ "dpdy" => Some((Axis::Y, Ctrl::None)),
+ "fwidth" => Some((Axis::Width, Ctrl::None)),
+ _ => None,
+ }
+}
+
+pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
+ match word {
+ "any" => Some(crate::RelationalFunction::Any),
+ "all" => Some(crate::RelationalFunction::All),
+ _ => None,
+ }
+}
+
+pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
+ use crate::MathFunction as Mf;
+ Some(match word {
+ // comparison
+ "abs" => Mf::Abs,
+ "min" => Mf::Min,
+ "max" => Mf::Max,
+ "clamp" => Mf::Clamp,
+ "saturate" => Mf::Saturate,
+ // trigonometry
+ "cos" => Mf::Cos,
+ "cosh" => Mf::Cosh,
+ "sin" => Mf::Sin,
+ "sinh" => Mf::Sinh,
+ "tan" => Mf::Tan,
+ "tanh" => Mf::Tanh,
+ "acos" => Mf::Acos,
+ "acosh" => Mf::Acosh,
+ "asin" => Mf::Asin,
+ "asinh" => Mf::Asinh,
+ "atan" => Mf::Atan,
+ "atanh" => Mf::Atanh,
+ "atan2" => Mf::Atan2,
+ "radians" => Mf::Radians,
+ "degrees" => Mf::Degrees,
+ // decomposition
+ "ceil" => Mf::Ceil,
+ "floor" => Mf::Floor,
+ "round" => Mf::Round,
+ "fract" => Mf::Fract,
+ "trunc" => Mf::Trunc,
+ "modf" => Mf::Modf,
+ "frexp" => Mf::Frexp,
+ "ldexp" => Mf::Ldexp,
+ // exponent
+ "exp" => Mf::Exp,
+ "exp2" => Mf::Exp2,
+ "log" => Mf::Log,
+ "log2" => Mf::Log2,
+ "pow" => Mf::Pow,
+ // geometry
+ "dot" => Mf::Dot,
+ "cross" => Mf::Cross,
+ "distance" => Mf::Distance,
+ "length" => Mf::Length,
+ "normalize" => Mf::Normalize,
+ "faceForward" => Mf::FaceForward,
+ "reflect" => Mf::Reflect,
+ "refract" => Mf::Refract,
+ // computational
+ "sign" => Mf::Sign,
+ "fma" => Mf::Fma,
+ "mix" => Mf::Mix,
+ "step" => Mf::Step,
+ "smoothstep" => Mf::SmoothStep,
+ "sqrt" => Mf::Sqrt,
+ "inverseSqrt" => Mf::InverseSqrt,
+ "transpose" => Mf::Transpose,
+ "determinant" => Mf::Determinant,
+ // bits
+ "countTrailingZeros" => Mf::CountTrailingZeros,
+ "countLeadingZeros" => Mf::CountLeadingZeros,
+ "countOneBits" => Mf::CountOneBits,
+ "reverseBits" => Mf::ReverseBits,
+ "extractBits" => Mf::ExtractBits,
+ "insertBits" => Mf::InsertBits,
+ "firstTrailingBit" => Mf::FindLsb,
+ "firstLeadingBit" => Mf::FindMsb,
+ // data packing
+ "pack4x8snorm" => Mf::Pack4x8snorm,
+ "pack4x8unorm" => Mf::Pack4x8unorm,
+ "pack2x16snorm" => Mf::Pack2x16snorm,
+ "pack2x16unorm" => Mf::Pack2x16unorm,
+ "pack2x16float" => Mf::Pack2x16float,
+ // data unpacking
+ "unpack4x8snorm" => Mf::Unpack4x8snorm,
+ "unpack4x8unorm" => Mf::Unpack4x8unorm,
+ "unpack2x16snorm" => Mf::Unpack2x16snorm,
+ "unpack2x16unorm" => Mf::Unpack2x16unorm,
+ "unpack2x16float" => Mf::Unpack2x16float,
+ _ => return None,
+ })
+}
+
+pub fn map_conservative_depth(
+ word: &str,
+ span: Span,
+) -> Result<crate::ConservativeDepth, Error<'_>> {
+ use crate::ConservativeDepth as Cd;
+ match word {
+ "greater_equal" => Ok(Cd::GreaterEqual),
+ "less_equal" => Ok(Cd::LessEqual),
+ "unchanged" => Ok(Cd::Unchanged),
+ _ => Err(Error::UnknownConservativeDepth(span)),
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/lexer.rs b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs
new file mode 100644
index 0000000000..d03a448561
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs
@@ -0,0 +1,739 @@
+use super::{number::consume_number, Error, ExpectedToken};
+use crate::front::wgsl::error::NumberError;
+use crate::front::wgsl::parse::{conv, Number};
+use crate::front::wgsl::Scalar;
+use crate::Span;
+
+type TokenSpan<'a> = (Token<'a>, Span);
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Token<'a> {
+ Separator(char),
+ Paren(char),
+ Attribute,
+ Number(Result<Number, NumberError>),
+ Word(&'a str),
+ Operation(char),
+ LogicalOperation(char),
+ ShiftOperation(char),
+ AssignmentOperation(char),
+ IncrementOperation,
+ DecrementOperation,
+ Arrow,
+ Unknown(char),
+ Trivia,
+ End,
+}
+
+fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) {
+ let pos = input.find(|c| !what(c)).unwrap_or(input.len());
+ input.split_at(pos)
+}
+
+/// Return the token at the start of `input`.
+///
+/// If `generic` is `false`, then the bit shift operators `>>` or `<<`
+/// are valid lookahead tokens for the current parser state (see [§3.1
+/// Parsing] in the WGSL specification). In other words:
+///
+/// - If `generic` is `true`, then we are expecting an angle bracket
+/// around a generic type parameter, like the `<` and `>` in
+/// `vec3<f32>`, so interpret `<` and `>` as `Token::Paren` tokens,
+/// even if they're part of `<<` or `>>` sequences.
+///
+/// - Otherwise, interpret `<<` and `>>` as shift operators:
+/// `Token::LogicalOperation` tokens.
+///
+/// [§3.1 Parsing]: https://gpuweb.github.io/gpuweb/wgsl/#parsing
+fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) {
+ let mut chars = input.chars();
+ let cur = match chars.next() {
+ Some(c) => c,
+ None => return (Token::End, ""),
+ };
+ match cur {
+ ':' | ';' | ',' => (Token::Separator(cur), chars.as_str()),
+ '.' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('0'..='9') => consume_number(input),
+ _ => (Token::Separator(cur), og_chars),
+ }
+ }
+ '@' => (Token::Attribute, chars.as_str()),
+ '(' | ')' | '{' | '}' | '[' | ']' => (Token::Paren(cur), chars.as_str()),
+ '<' | '>' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') if !generic => (Token::LogicalOperation(cur), chars.as_str()),
+ Some(c) if c == cur && !generic => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::ShiftOperation(cur), og_chars),
+ }
+ }
+ _ => (Token::Paren(cur), og_chars),
+ }
+ }
+ '0'..='9' => consume_number(input),
+ '/' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('/') => {
+ let _ = chars.position(is_comment_end);
+ (Token::Trivia, chars.as_str())
+ }
+ Some('*') => {
+ let mut depth = 1;
+ let mut prev = None;
+
+ for c in &mut chars {
+ match (prev, c) {
+ (Some('*'), '/') => {
+ prev = None;
+ depth -= 1;
+ if depth == 0 {
+ return (Token::Trivia, chars.as_str());
+ }
+ }
+ (Some('/'), '*') => {
+ prev = None;
+ depth += 1;
+ }
+ _ => {
+ prev = Some(c);
+ }
+ }
+ }
+
+ (Token::End, "")
+ }
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '-' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('>') => (Token::Arrow, chars.as_str()),
+ Some('-') => (Token::DecrementOperation, chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '+' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('+') => (Token::IncrementOperation, chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '*' | '%' | '^' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '~' => (Token::Operation(cur), chars.as_str()),
+ '=' | '!' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::LogicalOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '&' | '|' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some(c) if c == cur => (Token::LogicalOperation(cur), chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ _ if is_blankspace(cur) => {
+ let (_, rest) = consume_any(input, is_blankspace);
+ (Token::Trivia, rest)
+ }
+ _ if is_word_start(cur) => {
+ let (word, rest) = consume_any(input, is_word_part);
+ (Token::Word(word), rest)
+ }
+ _ => (Token::Unknown(cur), chars.as_str()),
+ }
+}
+
+/// Returns whether or not a char is a comment end
+/// (Unicode Pattern_White_Space excluding U+0020, U+0009, U+200E and U+200F)
+const fn is_comment_end(c: char) -> bool {
+ match c {
+ '\u{000a}'..='\u{000d}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true,
+ _ => false,
+ }
+}
+
+/// Returns whether or not a char is a blankspace (Unicode Pattern_White_Space)
+const fn is_blankspace(c: char) -> bool {
+ match c {
+ '\u{0020}'
+ | '\u{0009}'..='\u{000d}'
+ | '\u{0085}'
+ | '\u{200e}'
+ | '\u{200f}'
+ | '\u{2028}'
+ | '\u{2029}' => true,
+ _ => false,
+ }
+}
+
+/// Returns whether or not a char is a word start (Unicode XID_Start + '_')
+fn is_word_start(c: char) -> bool {
+ c == '_' || unicode_xid::UnicodeXID::is_xid_start(c)
+}
+
+/// Returns whether or not a char is a word part (Unicode XID_Continue)
+fn is_word_part(c: char) -> bool {
+ unicode_xid::UnicodeXID::is_xid_continue(c)
+}
+
+#[derive(Clone)]
+pub(in crate::front::wgsl) struct Lexer<'a> {
+ input: &'a str,
+ pub(in crate::front::wgsl) source: &'a str,
+ // The byte offset of the end of the last non-trivia token.
+ last_end_offset: usize,
+}
+
+impl<'a> Lexer<'a> {
+ pub(in crate::front::wgsl) const fn new(input: &'a str) -> Self {
+ Lexer {
+ input,
+ source: input,
+ last_end_offset: 0,
+ }
+ }
+
+ /// Calls the function with a lexer and returns the result of the function as well as the span for everything the function parsed
+ ///
+ /// # Examples
+ /// ```ignore
+ /// let lexer = Lexer::new("5");
+ /// let (value, span) = lexer.capture_span(Lexer::next_uint_literal);
+ /// assert_eq!(value, 5);
+ /// ```
+ #[inline]
+ pub fn capture_span<T, E>(
+ &mut self,
+ inner: impl FnOnce(&mut Self) -> Result<T, E>,
+ ) -> Result<(T, Span), E> {
+ let start = self.current_byte_offset();
+ let res = inner(self)?;
+ let end = self.current_byte_offset();
+ Ok((res, Span::from(start..end)))
+ }
+
+ pub(in crate::front::wgsl) fn start_byte_offset(&mut self) -> usize {
+ loop {
+ // Eat all trivia because `next` doesn't eat trailing trivia.
+ let (token, rest) = consume_token(self.input, false);
+ if let Token::Trivia = token {
+ self.input = rest;
+ } else {
+ return self.current_byte_offset();
+ }
+ }
+ }
+
+ fn peek_token_and_rest(&mut self) -> (TokenSpan<'a>, &'a str) {
+ let mut cloned = self.clone();
+ let token = cloned.next();
+ let rest = cloned.input;
+ (token, rest)
+ }
+
+ const fn current_byte_offset(&self) -> usize {
+ self.source.len() - self.input.len()
+ }
+
+ pub(in crate::front::wgsl) fn span_from(&self, offset: usize) -> Span {
+ Span::from(offset..self.last_end_offset)
+ }
+
+ /// Return the next non-whitespace token from `self`.
+ ///
+ /// Assume we are a parse state where bit shift operators may
+ /// occur, but not angle brackets.
+ #[must_use]
+ pub(in crate::front::wgsl) fn next(&mut self) -> TokenSpan<'a> {
+ self.next_impl(false)
+ }
+
+ /// Return the next non-whitespace token from `self`.
+ ///
+ /// Assume we are in a parse state where angle brackets may occur,
+ /// but not bit shift operators.
+ #[must_use]
+ pub(in crate::front::wgsl) fn next_generic(&mut self) -> TokenSpan<'a> {
+ self.next_impl(true)
+ }
+
+ /// Return the next non-whitespace token from `self`, with a span.
+ ///
+ /// See [`consume_token`] for the meaning of `generic`.
+ fn next_impl(&mut self, generic: bool) -> TokenSpan<'a> {
+ let mut start_byte_offset = self.current_byte_offset();
+ loop {
+ let (token, rest) = consume_token(self.input, generic);
+ self.input = rest;
+ match token {
+ Token::Trivia => start_byte_offset = self.current_byte_offset(),
+ _ => {
+ self.last_end_offset = self.current_byte_offset();
+ return (token, self.span_from(start_byte_offset));
+ }
+ }
+ }
+ }
+
+ #[must_use]
+ pub(in crate::front::wgsl) fn peek(&mut self) -> TokenSpan<'a> {
+ let (token, _) = self.peek_token_and_rest();
+ token
+ }
+
+ pub(in crate::front::wgsl) fn expect_span(
+ &mut self,
+ expected: Token<'a>,
+ ) -> Result<Span, Error<'a>> {
+ let next = self.next();
+ if next.0 == expected {
+ Ok(next.1)
+ } else {
+ Err(Error::Unexpected(next.1, ExpectedToken::Token(expected)))
+ }
+ }
+
+ pub(in crate::front::wgsl) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> {
+ self.expect_span(expected)?;
+ Ok(())
+ }
+
+ pub(in crate::front::wgsl) fn expect_generic_paren(
+ &mut self,
+ expected: char,
+ ) -> Result<(), Error<'a>> {
+ let next = self.next_generic();
+ if next.0 == Token::Paren(expected) {
+ Ok(())
+ } else {
+ Err(Error::Unexpected(
+ next.1,
+ ExpectedToken::Token(Token::Paren(expected)),
+ ))
+ }
+ }
+
+ /// If the next token matches it is skipped and true is returned
+ pub(in crate::front::wgsl) fn skip(&mut self, what: Token<'_>) -> bool {
+ let (peeked_token, rest) = self.peek_token_and_rest();
+ if peeked_token.0 == what {
+ self.input = rest;
+ true
+ } else {
+ false
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_ident_with_span(
+ &mut self,
+ ) -> Result<(&'a str, Span), Error<'a>> {
+ match self.next() {
+ (Token::Word("_"), span) => Err(Error::InvalidIdentifierUnderscore(span)),
+ (Token::Word(word), span) if word.starts_with("__") => {
+ Err(Error::ReservedIdentifierPrefix(span))
+ }
+ (Token::Word(word), span) => Ok((word, span)),
+ other => Err(Error::Unexpected(other.1, ExpectedToken::Identifier)),
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_ident(
+ &mut self,
+ ) -> Result<super::ast::Ident<'a>, Error<'a>> {
+ let ident = self
+ .next_ident_with_span()
+ .map(|(name, span)| super::ast::Ident { name, span })?;
+
+ if crate::keywords::wgsl::RESERVED.contains(&ident.name) {
+ return Err(Error::ReservedKeyword(ident.span));
+ }
+
+ Ok(ident)
+ }
+
+ /// Parses a generic scalar type, for example `<f32>`.
+ pub(in crate::front::wgsl) fn next_scalar_generic(&mut self) -> Result<Scalar, Error<'a>> {
+ self.expect_generic_paren('<')?;
+ let pair = match self.next() {
+ (Token::Word(word), span) => {
+ conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span))
+ }
+ (_, span) => Err(Error::UnknownScalarType(span)),
+ }?;
+ self.expect_generic_paren('>')?;
+ Ok(pair)
+ }
+
+ /// Parses a generic scalar type, for example `<f32>`.
+ ///
+ /// Returns the span covering the inner type, excluding the brackets.
+ pub(in crate::front::wgsl) fn next_scalar_generic_with_span(
+ &mut self,
+ ) -> Result<(Scalar, Span), Error<'a>> {
+ self.expect_generic_paren('<')?;
+ let pair = match self.next() {
+ (Token::Word(word), span) => conv::get_scalar_type(word)
+ .map(|scalar| (scalar, span))
+ .ok_or(Error::UnknownScalarType(span)),
+ (_, span) => Err(Error::UnknownScalarType(span)),
+ }?;
+ self.expect_generic_paren('>')?;
+ Ok(pair)
+ }
+
+ pub(in crate::front::wgsl) fn next_storage_access(
+ &mut self,
+ ) -> Result<crate::StorageAccess, Error<'a>> {
+ let (ident, span) = self.next_ident_with_span()?;
+ match ident {
+ "read" => Ok(crate::StorageAccess::LOAD),
+ "write" => Ok(crate::StorageAccess::STORE),
+ "read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE),
+ _ => Err(Error::UnknownAccess(span)),
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_format_generic(
+ &mut self,
+ ) -> Result<(crate::StorageFormat, crate::StorageAccess), Error<'a>> {
+ self.expect(Token::Paren('<'))?;
+ let (ident, ident_span) = self.next_ident_with_span()?;
+ let format = conv::map_storage_format(ident, ident_span)?;
+ self.expect(Token::Separator(','))?;
+ let access = self.next_storage_access()?;
+ self.expect(Token::Paren('>'))?;
+ Ok((format, access))
+ }
+
+ pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<(), Error<'a>> {
+ self.expect(Token::Paren('('))
+ }
+
+ pub(in crate::front::wgsl) fn close_arguments(&mut self) -> Result<(), Error<'a>> {
+ let _ = self.skip(Token::Separator(','));
+ self.expect(Token::Paren(')'))
+ }
+
+ pub(in crate::front::wgsl) fn next_argument(&mut self) -> Result<bool, Error<'a>> {
+ let paren = Token::Paren(')');
+ if self.skip(Token::Separator(',')) {
+ Ok(!self.skip(paren))
+ } else {
+ self.expect(paren).map(|()| false)
+ }
+ }
+}
+
+#[cfg(test)]
+#[track_caller]
+fn sub_test(source: &str, expected_tokens: &[Token]) {
+ let mut lex = Lexer::new(source);
+ for &token in expected_tokens {
+ assert_eq!(lex.next().0, token);
+ }
+ assert_eq!(lex.next().0, Token::End);
+}
+
+#[test]
+fn test_numbers() {
+ // WGSL spec examples //
+
+ // decimal integer
+ sub_test(
+ "0x123 0X123u 1u 123 0 0i 0x3f",
+ &[
+ Token::Number(Ok(Number::AbstractInt(291))),
+ Token::Number(Ok(Number::U32(291))),
+ Token::Number(Ok(Number::U32(1))),
+ Token::Number(Ok(Number::AbstractInt(123))),
+ Token::Number(Ok(Number::AbstractInt(0))),
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::AbstractInt(63))),
+ ],
+ );
+ // decimal floating point
+ sub_test(
+ "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h",
+ &[
+ Token::Number(Ok(Number::F32(0.))),
+ Token::Number(Ok(Number::AbstractFloat(1.))),
+ Token::Number(Ok(Number::AbstractFloat(0.01))),
+ Token::Number(Ok(Number::AbstractFloat(12.34))),
+ Token::Number(Ok(Number::F32(0.))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ Token::Number(Ok(Number::AbstractFloat(0.001))),
+ Token::Number(Ok(Number::AbstractFloat(43.75))),
+ Token::Number(Ok(Number::F32(16.))),
+ Token::Number(Ok(Number::AbstractFloat(0.1875))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ Token::Number(Ok(Number::AbstractFloat(0.12109375))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ ],
+ );
+
+ // MIN / MAX //
+
+ // min / max decimal integer
+ sub_test(
+ "0i 2147483647i 2147483648i",
+ &[
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::I32(i32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+ // min / max decimal unsigned integer
+ sub_test(
+ "0u 4294967295u 4294967296u",
+ &[
+ Token::Number(Ok(Number::U32(u32::MIN))),
+ Token::Number(Ok(Number::U32(u32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ // min / max hexadecimal signed integer
+ sub_test(
+ "0x0i 0x7FFFFFFFi 0x80000000i",
+ &[
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::I32(i32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+ // min / max hexadecimal unsigned integer
+ sub_test(
+ "0x0u 0xFFFFFFFFu 0x100000000u",
+ &[
+ Token::Number(Ok(Number::U32(u32::MIN))),
+ Token::Number(Ok(Number::U32(u32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ // min/max decimal abstract int
+ sub_test(
+ "0 9223372036854775807 9223372036854775808",
+ &[
+ Token::Number(Ok(Number::AbstractInt(0))),
+ Token::Number(Ok(Number::AbstractInt(i64::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ // min/max hexadecimal abstract int
+ sub_test(
+ "0 0x7fffffffffffffff 0x8000000000000000",
+ &[
+ Token::Number(Ok(Number::AbstractInt(0))),
+ Token::Number(Ok(Number::AbstractInt(i64::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ /// ≈ 2^-126 * 2^−23 (= 2^−149)
+ const SMALLEST_POSITIVE_SUBNORMAL_F32: f32 = 1e-45;
+ /// ≈ 2^-126 * (1 − 2^−23)
+ const LARGEST_SUBNORMAL_F32: f32 = 1.1754942e-38;
+ /// ≈ 2^-126
+ const SMALLEST_POSITIVE_NORMAL_F32: f32 = f32::MIN_POSITIVE;
+ /// ≈ 1 − 2^−24
+ const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994;
+ /// ≈ 1 + 2^−23
+ const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001;
+ /// ≈ 2^127 * (2 − 2^−23)
+ const LARGEST_NORMAL_F32: f32 = f32::MAX;
+
+ // decimal floating point
+ sub_test(
+ "1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f 3.40282347e+38f",
+ &[
+ Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
+ Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
+ Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
+ Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
+ Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
+ Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
+ ],
+ );
+ sub_test(
+ "3.40282367e+38f",
+ &[
+ Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128
+ ],
+ );
+
+ // hexadecimal floating point
+ sub_test(
+ "0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f 0xFFFFFFp+104f",
+ &[
+ Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
+ Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
+ Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
+ Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
+ Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
+ Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
+ ],
+ );
+ sub_test(
+ "0x1p128f 0x1.000001p0f",
+ &[
+ Token::Number(Err(NumberError::NotRepresentable)), // = 2^128
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+}
+
+#[test]
+fn double_floats() {
+ sub_test(
+ "0x1.2p4lf 0x1p8lf 0.0625lf 625e-4lf 10lf 10l",
+ &[
+ Token::Number(Ok(Number::F64(18.0))),
+ Token::Number(Ok(Number::F64(256.0))),
+ Token::Number(Ok(Number::F64(0.0625))),
+ Token::Number(Ok(Number::F64(0.0625))),
+ Token::Number(Ok(Number::F64(10.0))),
+ Token::Number(Ok(Number::AbstractInt(10))),
+ Token::Word("l"),
+ ],
+ )
+}
+
+#[test]
+fn test_tokens() {
+ sub_test("id123_OK", &[Token::Word("id123_OK")]);
+ sub_test(
+ "92No",
+ &[
+ Token::Number(Ok(Number::AbstractInt(92))),
+ Token::Word("No"),
+ ],
+ );
+ sub_test(
+ "2u3o",
+ &[
+ Token::Number(Ok(Number::U32(2))),
+ Token::Number(Ok(Number::AbstractInt(3))),
+ Token::Word("o"),
+ ],
+ );
+ sub_test(
+ "2.4f44po",
+ &[
+ Token::Number(Ok(Number::F32(2.4))),
+ Token::Number(Ok(Number::AbstractInt(44))),
+ Token::Word("po"),
+ ],
+ );
+ sub_test(
+ "Δέλτα réflexion Кызыл 𐰓𐰏𐰇 朝焼け سلام 검정 שָׁלוֹם गुलाबी փիրուզ",
+ &[
+ Token::Word("Δέλτα"),
+ Token::Word("réflexion"),
+ Token::Word("Кызыл"),
+ Token::Word("𐰓𐰏𐰇"),
+ Token::Word("朝焼け"),
+ Token::Word("سلام"),
+ Token::Word("검정"),
+ Token::Word("שָׁלוֹם"),
+ Token::Word("गुलाबी"),
+ Token::Word("փիրուզ"),
+ ],
+ );
+ sub_test("æNoø", &[Token::Word("æNoø")]);
+ sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]);
+ sub_test("No好", &[Token::Word("No好")]);
+ sub_test("_No", &[Token::Word("_No")]);
+ sub_test(
+ "*/*/***/*//=/*****//",
+ &[
+ Token::Operation('*'),
+ Token::AssignmentOperation('/'),
+ Token::Operation('/'),
+ ],
+ );
+
+ // Type suffixes are only allowed on hex float literals
+ // if you provided an exponent.
+ sub_test(
+ "0x1.2f 0x1.2f 0x1.2h 0x1.2H 0x1.2lf",
+ &[
+ // The 'f' suffixes are taken as a hex digit:
+ // the fractional part is 0x2f / 256.
+ Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))),
+ Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))),
+ Token::Number(Ok(Number::AbstractFloat(1.125))),
+ Token::Word("h"),
+ Token::Number(Ok(Number::AbstractFloat(1.125))),
+ Token::Word("H"),
+ Token::Number(Ok(Number::AbstractFloat(1.125))),
+ Token::Word("lf"),
+ ],
+ )
+}
+
+#[test]
+fn test_variable_decl() {
+ sub_test(
+ "@group(0 ) var< uniform> texture: texture_multisampled_2d <f32 >;",
+ &[
+ Token::Attribute,
+ Token::Word("group"),
+ Token::Paren('('),
+ Token::Number(Ok(Number::AbstractInt(0))),
+ Token::Paren(')'),
+ Token::Word("var"),
+ Token::Paren('<'),
+ Token::Word("uniform"),
+ Token::Paren('>'),
+ Token::Word("texture"),
+ Token::Separator(':'),
+ Token::Word("texture_multisampled_2d"),
+ Token::Paren('<'),
+ Token::Word("f32"),
+ Token::Paren('>'),
+ Token::Separator(';'),
+ ],
+ );
+ sub_test(
+ "var<storage,read_write> buffer: array<u32>;",
+ &[
+ Token::Word("var"),
+ Token::Paren('<'),
+ Token::Word("storage"),
+ Token::Separator(','),
+ Token::Word("read_write"),
+ Token::Paren('>'),
+ Token::Word("buffer"),
+ Token::Separator(':'),
+ Token::Word("array"),
+ Token::Paren('<'),
+ Token::Word("u32"),
+ Token::Paren('>'),
+ Token::Separator(';'),
+ ],
+ );
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
new file mode 100644
index 0000000000..51fc2f013b
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
@@ -0,0 +1,2350 @@
+use crate::front::wgsl::error::{Error, ExpectedToken};
+use crate::front::wgsl::parse::lexer::{Lexer, Token};
+use crate::front::wgsl::parse::number::Number;
+use crate::front::wgsl::Scalar;
+use crate::front::SymbolTable;
+use crate::{Arena, FastIndexSet, Handle, ShaderStage, Span};
+
+pub mod ast;
+pub mod conv;
+pub mod lexer;
+pub mod number;
+
+/// State for constructing an AST expression.
+///
+/// Not to be confused with [`lower::ExpressionContext`], which is for producing
+/// Naga IR from the AST we produce here.
+///
+/// [`lower::ExpressionContext`]: super::lower::ExpressionContext
+struct ExpressionContext<'input, 'temp, 'out> {
+ /// The [`TranslationUnit::expressions`] arena to which we should contribute
+ /// expressions.
+ ///
+ /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions
+ expressions: &'out mut Arena<ast::Expression<'input>>,
+
+ /// The [`TranslationUnit::types`] arena to which we should contribute new
+ /// types.
+ ///
+ /// [`TranslationUnit::types`]: ast::TranslationUnit::types
+ types: &'out mut Arena<ast::Type<'input>>,
+
+ /// A map from identifiers in scope to the locals/arguments they represent.
+ ///
+ /// The handles refer to the [`Function::locals`] area; see that field's
+ /// documentation for details.
+ ///
+ /// [`Function::locals`]: ast::Function::locals
+ local_table: &'temp mut SymbolTable<&'input str, Handle<ast::Local>>,
+
+ /// The [`Function::locals`] arena for the function we're building.
+ ///
+ /// [`Function::locals`]: ast::Function::locals
+ locals: &'out mut Arena<ast::Local>,
+
+ /// Identifiers used by the current global declaration that have no local definition.
+ ///
+ /// This becomes the [`GlobalDecl`]'s [`dependencies`] set.
+ ///
+ /// Note that we don't know at parse time what kind of [`GlobalDecl`] the
+ /// name refers to. We can't look up names until we've seen the entire
+ /// translation unit.
+ ///
+ /// [`GlobalDecl`]: ast::GlobalDecl
+ /// [`dependencies`]: ast::GlobalDecl::dependencies
+ unresolved: &'out mut FastIndexSet<ast::Dependency<'input>>,
+}
+
+impl<'a> ExpressionContext<'a, '_, '_> {
+ fn parse_binary_op(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>,
+ mut parser: impl FnMut(
+ &mut Lexer<'a>,
+ &mut Self,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let start = lexer.start_byte_offset();
+ let mut accumulator = parser(lexer, self)?;
+ while let Some(op) = classifier(lexer.peek().0) {
+ let _ = lexer.next();
+ let left = accumulator;
+ let right = parser(lexer, self)?;
+ accumulator = self.expressions.append(
+ ast::Expression::Binary { op, left, right },
+ lexer.span_from(start),
+ );
+ }
+ Ok(accumulator)
+ }
+
+ fn declare_local(&mut self, name: ast::Ident<'a>) -> Result<Handle<ast::Local>, Error<'a>> {
+ let handle = self.locals.append(ast::Local, name.span);
+ if let Some(old) = self.local_table.add(name.name, handle) {
+ Err(Error::Redefinition {
+ previous: self.locals.get_span(old),
+ current: name.span,
+ })
+ } else {
+ Ok(handle)
+ }
+ }
+}
+
+/// Which grammar rule we are in the midst of parsing.
+///
+/// This is used for error checking. `Parser` maintains a stack of
+/// these and (occasionally) checks that it is being pushed and popped
+/// as expected.
+#[derive(Clone, Debug, PartialEq)]
+enum Rule {
+ Attribute,
+ VariableDecl,
+ TypeDecl,
+ FunctionDecl,
+ Block,
+ Statement,
+ PrimaryExpr,
+ SingularExpr,
+ UnaryExpr,
+ GeneralExpr,
+}
+
+struct ParsedAttribute<T> {
+ value: Option<T>,
+}
+
+impl<T> Default for ParsedAttribute<T> {
+ fn default() -> Self {
+ Self { value: None }
+ }
+}
+
+impl<T> ParsedAttribute<T> {
+ fn set(&mut self, value: T, name_span: Span) -> Result<(), Error<'static>> {
+ if self.value.is_some() {
+ return Err(Error::RepeatedAttribute(name_span));
+ }
+ self.value = Some(value);
+ Ok(())
+ }
+}
+
+#[derive(Default)]
+struct BindingParser<'a> {
+ location: ParsedAttribute<Handle<ast::Expression<'a>>>,
+ second_blend_source: ParsedAttribute<bool>,
+ built_in: ParsedAttribute<crate::BuiltIn>,
+ interpolation: ParsedAttribute<crate::Interpolation>,
+ sampling: ParsedAttribute<crate::Sampling>,
+ invariant: ParsedAttribute<bool>,
+}
+
+impl<'a> BindingParser<'a> {
+ fn parse(
+ &mut self,
+ parser: &mut Parser,
+ lexer: &mut Lexer<'a>,
+ name: &'a str,
+ name_span: Span,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<(), Error<'a>> {
+ match name {
+ "location" => {
+ lexer.expect(Token::Paren('('))?;
+ self.location
+ .set(parser.general_expression(lexer, ctx)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "builtin" => {
+ lexer.expect(Token::Paren('('))?;
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.built_in
+ .set(conv::map_built_in(raw, span)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "interpolate" => {
+ lexer.expect(Token::Paren('('))?;
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.interpolation
+ .set(conv::map_interpolation(raw, span)?, name_span)?;
+ if lexer.skip(Token::Separator(',')) {
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.sampling
+ .set(conv::map_sampling(raw, span)?, name_span)?;
+ }
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "second_blend_source" => {
+ self.second_blend_source.set(true, name_span)?;
+ }
+ "invariant" => {
+ self.invariant.set(true, name_span)?;
+ }
+ _ => return Err(Error::UnknownAttribute(name_span)),
+ }
+ Ok(())
+ }
+
+ fn finish(self, span: Span) -> Result<Option<ast::Binding<'a>>, Error<'a>> {
+ match (
+ self.location.value,
+ self.built_in.value,
+ self.interpolation.value,
+ self.sampling.value,
+ self.invariant.value.unwrap_or_default(),
+ ) {
+ (None, None, None, None, false) => Ok(None),
+ (Some(location), None, interpolation, sampling, false) => {
+ // Before handing over the completed `Module`, we call
+ // `apply_default_interpolation` to ensure that the interpolation and
+ // sampling have been explicitly specified on all vertex shader output and fragment
+ // shader input user bindings, so leaving them potentially `None` here is fine.
+ Ok(Some(ast::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source: self.second_blend_source.value.unwrap_or(false),
+ }))
+ }
+ (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant) => {
+ Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position {
+ invariant,
+ })))
+ }
+ (None, Some(built_in), None, None, false) => Ok(Some(ast::Binding::BuiltIn(built_in))),
+ (_, _, _, _, _) => Err(Error::InconsistentBinding(span)),
+ }
+ }
+}
+
+pub struct Parser {
+ rules: Vec<(Rule, usize)>,
+}
+
+impl Parser {
+ pub const fn new() -> Self {
+ Parser { rules: Vec::new() }
+ }
+
+ fn reset(&mut self) {
+ self.rules.clear();
+ }
+
+ fn push_rule_span(&mut self, rule: Rule, lexer: &mut Lexer<'_>) {
+ self.rules.push((rule, lexer.start_byte_offset()));
+ }
+
+ fn pop_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
+ let (_, initial) = self.rules.pop().unwrap();
+ lexer.span_from(initial)
+ }
+
+ fn peek_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
+ let &(_, initial) = self.rules.last().unwrap();
+ lexer.span_from(initial)
+ }
+
+ fn switch_value<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::SwitchValue<'a>, Error<'a>> {
+ if let Token::Word("default") = lexer.peek().0 {
+ let _ = lexer.next();
+ return Ok(ast::SwitchValue::Default);
+ }
+
+ let expr = self.general_expression(lexer, ctx)?;
+ Ok(ast::SwitchValue::Expr(expr))
+ }
+
+ /// Decide if we're looking at a construction expression, and return its
+ /// type if so.
+ ///
+ /// If the identifier `word` is a [type-defining keyword], then return a
+ /// [`ConstructorType`] value describing the type to build. Return an error
+ /// if the type is not constructible (like `sampler`).
+ ///
+ /// If `word` isn't a type name, then return `None`.
+ ///
+ /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords
+ /// [`ConstructorType`]: ast::ConstructorType
+ fn constructor_type<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ word: &'a str,
+ span: Span,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::ConstructorType<'a>>, Error<'a>> {
+ if let Some(scalar) = conv::get_scalar_type(word) {
+ return Ok(Some(ast::ConstructorType::Scalar(scalar)));
+ }
+
+ let partial = match word {
+ "vec2" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Bi,
+ },
+ "vec2i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ }))
+ }
+ "vec2u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ }))
+ }
+ "vec2f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar::F32,
+ }))
+ }
+ "vec3" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Tri,
+ },
+ "vec3i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar::I32,
+ }))
+ }
+ "vec3u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar::U32,
+ }))
+ }
+ "vec3f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar::F32,
+ }))
+ }
+ "vec4" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Quad,
+ },
+ "vec4i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar::I32,
+ }))
+ }
+ "vec4u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar::U32,
+ }))
+ }
+ "vec4f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar::F32,
+ }))
+ }
+ "mat2x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat2x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat2x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat2x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat2x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat2x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "mat3x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat3x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat3x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat3x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat3x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat3x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "mat4x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat4x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat4x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat4x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat4x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat4x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "array" => ast::ConstructorType::PartialArray,
+ "atomic"
+ | "binding_array"
+ | "sampler"
+ | "sampler_comparison"
+ | "texture_1d"
+ | "texture_1d_array"
+ | "texture_2d"
+ | "texture_2d_array"
+ | "texture_3d"
+ | "texture_cube"
+ | "texture_cube_array"
+ | "texture_multisampled_2d"
+ | "texture_multisampled_2d_array"
+ | "texture_depth_2d"
+ | "texture_depth_2d_array"
+ | "texture_depth_cube"
+ | "texture_depth_cube_array"
+ | "texture_depth_multisampled_2d"
+ | "texture_storage_1d"
+ | "texture_storage_1d_array"
+ | "texture_storage_2d"
+ | "texture_storage_2d_array"
+ | "texture_storage_3d" => return Err(Error::TypeNotConstructible(span)),
+ _ => return Ok(None),
+ };
+
+ // parse component type if present
+ match (lexer.peek().0, partial) {
+ (Token::Paren('<'), ast::ConstructorType::PartialVector { size }) => {
+ let scalar = lexer.next_scalar_generic()?;
+ Ok(Some(ast::ConstructorType::Vector { size, scalar }))
+ }
+ (Token::Paren('<'), ast::ConstructorType::PartialMatrix { columns, rows }) => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ match scalar.kind {
+ crate::ScalarKind::Float => Ok(Some(ast::ConstructorType::Matrix {
+ columns,
+ rows,
+ width: scalar.width,
+ })),
+ _ => Err(Error::BadMatrixScalarKind(span, scalar)),
+ }
+ }
+ (Token::Paren('<'), ast::ConstructorType::PartialArray) => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx)?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let expr = self.unary_expression(lexer, ctx)?;
+ ast::ArraySize::Constant(expr)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ Ok(Some(ast::ConstructorType::Array { base, size }))
+ }
+ (_, partial) => Ok(Some(partial)),
+ }
+ }
+
+ /// Expects `name` to be consumed (not in lexer).
+ fn arguments<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Vec<Handle<ast::Expression<'a>>>, Error<'a>> {
+ lexer.open_arguments()?;
+ let mut arguments = Vec::new();
+ loop {
+ if !arguments.is_empty() {
+ if !lexer.next_argument()? {
+ break;
+ }
+ } else if lexer.skip(Token::Paren(')')) {
+ break;
+ }
+ let arg = self.general_expression(lexer, ctx)?;
+ arguments.push(arg);
+ }
+
+ Ok(arguments)
+ }
+
+ /// Expects [`Rule::PrimaryExpr`] or [`Rule::SingularExpr`] on top; does not pop it.
+ /// Expects `name` to be consumed (not in lexer).
+ fn function_call<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ name: &'a str,
+ name_span: Span,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ assert!(self.rules.last().is_some());
+
+ let expr = match name {
+ // bitcast looks like a function call, but it's an operator and must be handled differently.
+ "bitcast" => {
+ lexer.expect_generic_paren('<')?;
+ let start = lexer.start_byte_offset();
+ let to = self.type_decl(lexer, ctx)?;
+ let span = lexer.span_from(start);
+ lexer.expect_generic_paren('>')?;
+
+ lexer.open_arguments()?;
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.close_arguments()?;
+
+ ast::Expression::Bitcast {
+ expr,
+ to,
+ ty_span: span,
+ }
+ }
+ // everything else must be handled later, since they can be hidden by user-defined functions.
+ _ => {
+ let arguments = self.arguments(lexer, ctx)?;
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: name_span,
+ });
+ ast::Expression::Call {
+ function: ast::Ident {
+ name,
+ span: name_span,
+ },
+ arguments,
+ }
+ }
+ };
+
+ let span = self.peek_rule_span(lexer);
+ let expr = ctx.expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ fn ident_expr<'a>(
+ &mut self,
+ name: &'a str,
+ name_span: Span,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> ast::IdentExpr<'a> {
+ match ctx.local_table.lookup(name) {
+ Some(&local) => ast::IdentExpr::Local(local),
+ None => {
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: name_span,
+ });
+ ast::IdentExpr::Unresolved(name)
+ }
+ }
+ }
+
+ fn primary_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::PrimaryExpr, lexer);
+
+ let expr = match lexer.peek() {
+ (Token::Paren('('), _) => {
+ let _ = lexer.next();
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren(')'))?;
+ self.pop_rule_span(lexer);
+ return Ok(expr);
+ }
+ (Token::Word("true"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Bool(true))
+ }
+ (Token::Word("false"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Bool(false))
+ }
+ (Token::Number(res), span) => {
+ let _ = lexer.next();
+ let num = res.map_err(|err| Error::BadNumber(span, err))?;
+ ast::Expression::Literal(ast::Literal::Number(num))
+ }
+ (Token::Word("RAY_FLAG_NONE"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+ }
+ (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(4)))
+ }
+ (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+ }
+ (Token::Word(word), span) => {
+ let start = lexer.start_byte_offset();
+ let _ = lexer.next();
+
+ if let Some(ty) = self.constructor_type(lexer, word, span, ctx)? {
+ let ty_span = lexer.span_from(start);
+ let components = self.arguments(lexer, ctx)?;
+ ast::Expression::Construct {
+ ty,
+ ty_span,
+ components,
+ }
+ } else if let Token::Paren('(') = lexer.peek().0 {
+ self.pop_rule_span(lexer);
+ return self.function_call(lexer, word, span, ctx);
+ } else if word == "bitcast" {
+ self.pop_rule_span(lexer);
+ return self.function_call(lexer, word, span, ctx);
+ } else {
+ let ident = self.ident_expr(word, span, ctx);
+ ast::Expression::Ident(ident)
+ }
+ }
+ other => return Err(Error::Unexpected(other.1, ExpectedToken::PrimaryExpression)),
+ };
+
+ let span = self.pop_rule_span(lexer);
+ let expr = ctx.expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ fn postfix<'a>(
+ &mut self,
+ span_start: usize,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ expr: Handle<ast::Expression<'a>>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let mut expr = expr;
+
+ loop {
+ let expression = match lexer.peek().0 {
+ Token::Separator('.') => {
+ let _ = lexer.next();
+ let field = lexer.next_ident()?;
+
+ ast::Expression::Member { base: expr, field }
+ }
+ Token::Paren('[') => {
+ let _ = lexer.next();
+ let index = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren(']'))?;
+
+ ast::Expression::Index { base: expr, index }
+ }
+ _ => break,
+ };
+
+ let span = lexer.span_from(span_start);
+ expr = ctx.expressions.append(expression, span);
+ }
+
+ Ok(expr)
+ }
+
+ /// Parse a `unary_expression`.
+ fn unary_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::UnaryExpr, lexer);
+ //TODO: refactor this to avoid backing up
+ let expr = match lexer.peek().0 {
+ Token::Operation('-') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('!') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::LogicalNot,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('~') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::BitwiseNot,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('*') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::Deref(expr);
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('&') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::AddrOf(expr);
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ _ => self.singular_expression(lexer, ctx)?,
+ };
+
+ self.pop_rule_span(lexer);
+ Ok(expr)
+ }
+
+ /// Parse a `singular_expression`.
+ fn singular_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let start = lexer.start_byte_offset();
+ self.push_rule_span(Rule::SingularExpr, lexer);
+ let primary_expr = self.primary_expression(lexer, ctx)?;
+ let singular_expr = self.postfix(start, lexer, ctx, primary_expr)?;
+ self.pop_rule_span(lexer);
+
+ Ok(singular_expr)
+ }
+
+ fn equality_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ context: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ // equality_expression
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal),
+ Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual),
+ _ => None,
+ },
+ // relational_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Paren('<') => Some(crate::BinaryOperator::Less),
+ Token::Paren('>') => Some(crate::BinaryOperator::Greater),
+ Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual),
+ Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual),
+ _ => None,
+ },
+ // shift_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::ShiftOperation('<') => {
+ Some(crate::BinaryOperator::ShiftLeft)
+ }
+ Token::ShiftOperation('>') => {
+ Some(crate::BinaryOperator::ShiftRight)
+ }
+ _ => None,
+ },
+ // additive_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('+') => Some(crate::BinaryOperator::Add),
+ Token::Operation('-') => {
+ Some(crate::BinaryOperator::Subtract)
+ }
+ _ => None,
+ },
+ // multiplicative_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('*') => {
+ Some(crate::BinaryOperator::Multiply)
+ }
+ Token::Operation('/') => {
+ Some(crate::BinaryOperator::Divide)
+ }
+ Token::Operation('%') => {
+ Some(crate::BinaryOperator::Modulo)
+ }
+ _ => None,
+ },
+ |lexer, context| self.unary_expression(lexer, context),
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ }
+
+ fn general_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.general_expression_with_span(lexer, ctx)
+ .map(|(expr, _)| expr)
+ }
+
+ fn general_expression_with_span<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ context: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<(Handle<ast::Expression<'a>>, Span), Error<'a>> {
+ self.push_rule_span(Rule::GeneralExpr, lexer);
+ // logical_or_expression
+ let handle = context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr),
+ _ => None,
+ },
+ // logical_and_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd),
+ _ => None,
+ },
+ // inclusive_or_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr),
+ _ => None,
+ },
+ // exclusive_or_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('^') => {
+ Some(crate::BinaryOperator::ExclusiveOr)
+ }
+ _ => None,
+ },
+ // and_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('&') => {
+ Some(crate::BinaryOperator::And)
+ }
+ _ => None,
+ },
+ |lexer, context| {
+ self.equality_expression(lexer, context)
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )?;
+ Ok((handle, self.pop_rule_span(lexer)))
+ }
+
+ fn variable_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::GlobalVariable<'a>, Error<'a>> {
+ self.push_rule_span(Rule::VariableDecl, lexer);
+ let mut space = crate::AddressSpace::Handle;
+
+ if lexer.skip(Token::Paren('<')) {
+ let (class_str, span) = lexer.next_ident_with_span()?;
+ space = match class_str {
+ "storage" => {
+ let access = if lexer.skip(Token::Separator(',')) {
+ lexer.next_storage_access()?
+ } else {
+ // defaulting to `read`
+ crate::StorageAccess::LOAD
+ };
+ crate::AddressSpace::Storage { access }
+ }
+ _ => conv::map_address_space(class_str, span)?,
+ };
+ lexer.expect(Token::Paren('>'))?;
+ }
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.type_decl(lexer, ctx)?;
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ let handle = self.general_expression(lexer, ctx)?;
+ Some(handle)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+
+ Ok(ast::GlobalVariable {
+ name,
+ space,
+ binding: None,
+ ty,
+ init,
+ })
+ }
+
+ fn struct_body<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Vec<ast::StructMember<'a>>, Error<'a>> {
+ let mut members = Vec::new();
+
+ lexer.expect(Token::Paren('{'))?;
+ let mut ready = true;
+ while !lexer.skip(Token::Paren('}')) {
+ if !ready {
+ return Err(Error::Unexpected(
+ lexer.next().1,
+ ExpectedToken::Token(Token::Separator(',')),
+ ));
+ }
+ let (mut size, mut align) = (ParsedAttribute::default(), ParsedAttribute::default());
+ self.push_rule_span(Rule::Attribute, lexer);
+ let mut bind_parser = BindingParser::default();
+ while lexer.skip(Token::Attribute) {
+ match lexer.next_ident_with_span()? {
+ ("size", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren(')'))?;
+ size.set(expr, name_span)?;
+ }
+ ("align", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren(')'))?;
+ align.set(expr, name_span)?;
+ }
+ (word, word_span) => bind_parser.parse(self, lexer, word, word_span, ctx)?,
+ }
+ }
+
+ let bind_span = self.pop_rule_span(lexer);
+ let binding = bind_parser.finish(bind_span)?;
+
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.type_decl(lexer, ctx)?;
+ ready = lexer.skip(Token::Separator(','));
+
+ members.push(ast::StructMember {
+ name,
+ ty,
+ binding,
+ size: size.value,
+ align: align.value,
+ });
+ }
+
+ Ok(members)
+ }
+
+ fn matrix_scalar_type<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ ) -> Result<ast::Type<'a>, Error<'a>> {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ match scalar.kind {
+ crate::ScalarKind::Float => Ok(ast::Type::Matrix {
+ columns,
+ rows,
+ width: scalar.width,
+ }),
+ _ => Err(Error::BadMatrixScalarKind(span, scalar)),
+ }
+ }
+
+ fn type_decl_impl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ word: &'a str,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::Type<'a>>, Error<'a>> {
+ if let Some(scalar) = conv::get_scalar_type(word) {
+ return Ok(Some(ast::Type::Scalar(scalar)));
+ }
+
+ Ok(Some(match word {
+ "vec2" => {
+ let scalar = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ scalar,
+ }
+ }
+ "vec2i" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ },
+ "vec2u" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ "vec2f" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar::F32,
+ },
+ "vec3" => {
+ let scalar = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ scalar,
+ }
+ }
+ "vec3i" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ },
+ "vec3u" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ "vec3f" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar::F32,
+ },
+ "vec4" => {
+ let scalar = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ scalar,
+ }
+ }
+ "vec4i" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ },
+ "vec4u" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ "vec4f" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar::F32,
+ },
+ "mat2x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Bi)?
+ }
+ "mat2x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat2x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Tri)?
+ }
+ "mat2x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat2x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Quad)?
+ }
+ "mat2x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "mat3x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Bi)?
+ }
+ "mat3x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat3x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Tri)?
+ }
+ "mat3x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat3x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Quad)?
+ }
+ "mat3x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "mat4x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Bi)?
+ }
+ "mat4x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat4x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Tri)?
+ }
+ "mat4x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat4x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Quad)?
+ }
+ "mat4x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "atomic" => {
+ let scalar = lexer.next_scalar_generic()?;
+ ast::Type::Atomic(scalar)
+ }
+ "ptr" => {
+ lexer.expect_generic_paren('<')?;
+ let (ident, span) = lexer.next_ident_with_span()?;
+ let mut space = conv::map_address_space(ident, span)?;
+ lexer.expect(Token::Separator(','))?;
+ let base = self.type_decl(lexer, ctx)?;
+ if let crate::AddressSpace::Storage { ref mut access } = space {
+ *access = if lexer.skip(Token::Separator(',')) {
+ lexer.next_storage_access()?
+ } else {
+ crate::StorageAccess::LOAD
+ };
+ }
+ lexer.expect_generic_paren('>')?;
+ ast::Type::Pointer { base, space }
+ }
+ "array" => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx)?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let size = self.unary_expression(lexer, ctx)?;
+ ast::ArraySize::Constant(size)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ ast::Type::Array { base, size }
+ }
+ "binding_array" => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx)?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let size = self.unary_expression(lexer, ctx)?;
+ ast::ArraySize::Constant(size)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ ast::Type::BindingArray { base, size }
+ }
+ "sampler" => ast::Type::Sampler { comparison: false },
+ "sampler_comparison" => ast::Type::Sampler { comparison: true },
+ "texture_1d" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_1d_array" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_2d" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_2d_array" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_3d" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_cube" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_cube_array" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_multisampled_2d" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: true,
+ },
+ }
+ }
+ "texture_multisampled_2d_array" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: true,
+ },
+ }
+ }
+ "texture_depth_2d" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_2d_array" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_cube" => ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_cube_array" => ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_multisampled_2d" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: true },
+ },
+ "texture_storage_1d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_1d_array" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_2d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_2d_array" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_3d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "acceleration_structure" => ast::Type::AccelerationStructure,
+ "ray_query" => ast::Type::RayQuery,
+ "RayDesc" => ast::Type::RayDesc,
+ "RayIntersection" => ast::Type::RayIntersection,
+ _ => return Ok(None),
+ }))
+ }
+
+ const fn check_texture_sample_type(scalar: Scalar, span: Span) -> Result<(), Error<'static>> {
+ use crate::ScalarKind::*;
+ // Validate according to https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ match scalar {
+ Scalar {
+ kind: Float | Sint | Uint,
+ width: 4,
+ } => Ok(()),
+ _ => Err(Error::BadTextureSampleType { span, scalar }),
+ }
+ }
+
+ /// Parse type declaration of a given name.
+ fn type_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Type<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::TypeDecl, lexer);
+
+ let (name, span) = lexer.next_ident_with_span()?;
+
+ let ty = match self.type_decl_impl(lexer, name, ctx)? {
+ Some(ty) => ty,
+ None => {
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: span,
+ });
+ ast::Type::User(ast::Ident { name, span })
+ }
+ };
+
+ self.pop_rule_span(lexer);
+
+ let handle = ctx.types.append(ty, Span::UNDEFINED);
+ Ok(handle)
+ }
+
+ fn assignment_op_and_rhs<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ target: Handle<ast::Expression<'a>>,
+ span_start: usize,
+ ) -> Result<(), Error<'a>> {
+ use crate::BinaryOperator as Bo;
+
+ let op = lexer.next();
+ let (op, value) = match op {
+ (Token::Operation('='), _) => {
+ let value = self.general_expression(lexer, ctx)?;
+ (None, value)
+ }
+ (Token::AssignmentOperation(c), _) => {
+ let op = match c {
+ '<' => Bo::ShiftLeft,
+ '>' => Bo::ShiftRight,
+ '+' => Bo::Add,
+ '-' => Bo::Subtract,
+ '*' => Bo::Multiply,
+ '/' => Bo::Divide,
+ '%' => Bo::Modulo,
+ '&' => Bo::And,
+ '|' => Bo::InclusiveOr,
+ '^' => Bo::ExclusiveOr,
+ // Note: `consume_token` shouldn't produce any other assignment ops
+ _ => unreachable!(),
+ };
+
+ let value = self.general_expression(lexer, ctx)?;
+ (Some(op), value)
+ }
+ token @ (Token::IncrementOperation | Token::DecrementOperation, _) => {
+ let op = match token.0 {
+ Token::IncrementOperation => ast::StatementKind::Increment,
+ Token::DecrementOperation => ast::StatementKind::Decrement,
+ _ => unreachable!(),
+ };
+
+ let span = lexer.span_from(span_start);
+ block.stmts.push(ast::Statement {
+ kind: op(target),
+ span,
+ });
+ return Ok(());
+ }
+ _ => return Err(Error::Unexpected(op.1, ExpectedToken::Assignment)),
+ };
+
+ let span = lexer.span_from(span_start);
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Assign { target, op, value },
+ span,
+ });
+ Ok(())
+ }
+
+ /// Parse an assignment statement (will also parse increment and decrement statements)
+ fn assignment_statement<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ let span_start = lexer.start_byte_offset();
+ let target = self.general_expression(lexer, ctx)?;
+ self.assignment_op_and_rhs(lexer, ctx, block, target, span_start)
+ }
+
+ /// Parse a function call statement.
+ /// Expects `ident` to be consumed (not in the lexer).
+ fn function_statement<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ident: &'a str,
+ ident_span: Span,
+ span_start: usize,
+ context: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ self.push_rule_span(Rule::SingularExpr, lexer);
+
+ context.unresolved.insert(ast::Dependency {
+ ident,
+ usage: ident_span,
+ });
+ let arguments = self.arguments(lexer, context)?;
+ let span = lexer.span_from(span_start);
+
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Call {
+ function: ast::Ident {
+ name: ident,
+ span: ident_span,
+ },
+ arguments,
+ },
+ span,
+ });
+
+ self.pop_rule_span(lexer);
+
+ Ok(())
+ }
+
+ fn function_call_or_assignment_statement<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ context: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ let span_start = lexer.start_byte_offset();
+ match lexer.peek() {
+ (Token::Word(name), span) => {
+ // A little hack for 2 token lookahead.
+ let cloned = lexer.clone();
+ let _ = lexer.next();
+ match lexer.peek() {
+ (Token::Paren('('), _) => {
+ self.function_statement(lexer, name, span, span_start, context, block)
+ }
+ _ => {
+ *lexer = cloned;
+ self.assignment_statement(lexer, context, block)
+ }
+ }
+ }
+ _ => self.assignment_statement(lexer, context, block),
+ }
+ }
+
+ fn statement<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ self.push_rule_span(Rule::Statement, lexer);
+ match lexer.peek() {
+ (Token::Separator(';'), _) => {
+ let _ = lexer.next();
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ (Token::Paren('{'), _) => {
+ let (inner, span) = self.block(lexer, ctx)?;
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(inner),
+ span,
+ });
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ (Token::Word(word), _) => {
+ let kind = match word {
+ "_" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Operation('='))?;
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+
+ ast::StatementKind::Ignore(expr)
+ }
+ "let" => {
+ let _ = lexer.next();
+ let name = lexer.next_ident()?;
+
+ let given_ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx)?;
+ Some(ty)
+ } else {
+ None
+ };
+ lexer.expect(Token::Operation('='))?;
+ let expr_id = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+
+ let handle = ctx.declare_local(name)?;
+ ast::StatementKind::LocalDecl(ast::LocalDecl::Let(ast::Let {
+ name,
+ ty: given_ty,
+ init: expr_id,
+ handle,
+ }))
+ }
+ "var" => {
+ let _ = lexer.next();
+
+ let name = lexer.next_ident()?;
+ let ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx)?;
+ Some(ty)
+ } else {
+ None
+ };
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ let init = self.general_expression(lexer, ctx)?;
+ Some(init)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Separator(';'))?;
+
+ let handle = ctx.declare_local(name)?;
+ ast::StatementKind::LocalDecl(ast::LocalDecl::Var(ast::LocalVariable {
+ name,
+ ty,
+ init,
+ handle,
+ }))
+ }
+ "return" => {
+ let _ = lexer.next();
+ let value = if lexer.peek().0 != Token::Separator(';') {
+ let handle = self.general_expression(lexer, ctx)?;
+ Some(handle)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Return { value }
+ }
+ "if" => {
+ let _ = lexer.next();
+ let condition = self.general_expression(lexer, ctx)?;
+
+ let accept = self.block(lexer, ctx)?.0;
+
+ let mut elsif_stack = Vec::new();
+ let mut elseif_span_start = lexer.start_byte_offset();
+ let mut reject = loop {
+ if !lexer.skip(Token::Word("else")) {
+ break ast::Block::default();
+ }
+
+ if !lexer.skip(Token::Word("if")) {
+ // ... else { ... }
+ break self.block(lexer, ctx)?.0;
+ }
+
+ // ... else if (...) { ... }
+ let other_condition = self.general_expression(lexer, ctx)?;
+ let other_block = self.block(lexer, ctx)?;
+ elsif_stack.push((elseif_span_start, other_condition, other_block));
+ elseif_span_start = lexer.start_byte_offset();
+ };
+
+ // reverse-fold the else-if blocks
+ //Note: we may consider uplifting this to the IR
+ for (other_span_start, other_cond, other_block) in
+ elsif_stack.into_iter().rev()
+ {
+ let sub_stmt = ast::StatementKind::If {
+ condition: other_cond,
+ accept: other_block.0,
+ reject,
+ };
+ reject = ast::Block::default();
+ let span = lexer.span_from(other_span_start);
+ reject.stmts.push(ast::Statement {
+ kind: sub_stmt,
+ span,
+ })
+ }
+
+ ast::StatementKind::If {
+ condition,
+ accept,
+ reject,
+ }
+ }
+ "switch" => {
+ let _ = lexer.next();
+ let selector = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren('{'))?;
+ let mut cases = Vec::new();
+
+ loop {
+ // cases + default
+ match lexer.next() {
+ (Token::Word("case"), _) => {
+ // parse a list of values
+ let value = loop {
+ let value = self.switch_value(lexer, ctx)?;
+ if lexer.skip(Token::Separator(',')) {
+ if lexer.skip(Token::Separator(':')) {
+ break value;
+ }
+ } else {
+ lexer.skip(Token::Separator(':'));
+ break value;
+ }
+ cases.push(ast::SwitchCase {
+ value,
+ body: ast::Block::default(),
+ fall_through: true,
+ });
+ };
+
+ let body = self.block(lexer, ctx)?.0;
+
+ cases.push(ast::SwitchCase {
+ value,
+ body,
+ fall_through: false,
+ });
+ }
+ (Token::Word("default"), _) => {
+ lexer.skip(Token::Separator(':'));
+ let body = self.block(lexer, ctx)?.0;
+ cases.push(ast::SwitchCase {
+ value: ast::SwitchValue::Default,
+ body,
+ fall_through: false,
+ });
+ }
+ (Token::Paren('}'), _) => break,
+ (_, span) => {
+ return Err(Error::Unexpected(span, ExpectedToken::SwitchItem))
+ }
+ }
+ }
+
+ ast::StatementKind::Switch { selector, cases }
+ }
+ "loop" => self.r#loop(lexer, ctx)?,
+ "while" => {
+ let _ = lexer.next();
+ let mut body = ast::Block::default();
+
+ let (condition, span) = lexer.capture_span(|lexer| {
+ let condition = self.general_expression(lexer, ctx)?;
+ Ok(condition)
+ })?;
+ let mut reject = ast::Block::default();
+ reject.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Break,
+ span,
+ });
+
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::If {
+ condition,
+ accept: ast::Block::default(),
+ reject,
+ },
+ span,
+ });
+
+ let (block, span) = self.block(lexer, ctx)?;
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(block),
+ span,
+ });
+
+ ast::StatementKind::Loop {
+ body,
+ continuing: ast::Block::default(),
+ break_if: None,
+ }
+ }
+ "for" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Paren('('))?;
+
+ ctx.local_table.push_scope();
+
+ if !lexer.skip(Token::Separator(';')) {
+ let num_statements = block.stmts.len();
+ let (_, span) = {
+ let ctx = &mut *ctx;
+ let block = &mut *block;
+ lexer.capture_span(|lexer| self.statement(lexer, ctx, block))?
+ };
+
+ if block.stmts.len() != num_statements {
+ match block.stmts.last().unwrap().kind {
+ ast::StatementKind::Call { .. }
+ | ast::StatementKind::Assign { .. }
+ | ast::StatementKind::LocalDecl(_) => {}
+ _ => return Err(Error::InvalidForInitializer(span)),
+ }
+ }
+ };
+
+ let mut body = ast::Block::default();
+ if !lexer.skip(Token::Separator(';')) {
+ let (condition, span) = lexer.capture_span(|lexer| {
+ let condition = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+ Ok(condition)
+ })?;
+ let mut reject = ast::Block::default();
+ reject.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Break,
+ span,
+ });
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::If {
+ condition,
+ accept: ast::Block::default(),
+ reject,
+ },
+ span,
+ });
+ };
+
+ let mut continuing = ast::Block::default();
+ if !lexer.skip(Token::Paren(')')) {
+ self.function_call_or_assignment_statement(
+ lexer,
+ ctx,
+ &mut continuing,
+ )?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+
+ let (block, span) = self.block(lexer, ctx)?;
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(block),
+ span,
+ });
+
+ ctx.local_table.pop_scope();
+
+ ast::StatementKind::Loop {
+ body,
+ continuing,
+ break_if: None,
+ }
+ }
+ "break" => {
+ let (_, span) = lexer.next();
+ // Check if the next token is an `if`, this indicates
+ // that the user tried to type out a `break if` which
+ // is illegal in this position.
+ let (peeked_token, peeked_span) = lexer.peek();
+ if let Token::Word("if") = peeked_token {
+ let span = span.until(&peeked_span);
+ return Err(Error::InvalidBreakIf(span));
+ }
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Break
+ }
+ "continue" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Continue
+ }
+ "discard" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Kill
+ }
+ // assignment or a function call
+ _ => {
+ self.function_call_or_assignment_statement(lexer, ctx, block)?;
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ };
+
+ let span = self.pop_rule_span(lexer);
+ block.stmts.push(ast::Statement { kind, span });
+ }
+ _ => {
+ self.assignment_statement(lexer, ctx, block)?;
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+ }
+ }
+ Ok(())
+ }
+
+ fn r#loop<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::StatementKind<'a>, Error<'a>> {
+ let _ = lexer.next();
+ let mut body = ast::Block::default();
+ let mut continuing = ast::Block::default();
+ let mut break_if = None;
+
+ lexer.expect(Token::Paren('{'))?;
+
+ ctx.local_table.push_scope();
+
+ loop {
+ if lexer.skip(Token::Word("continuing")) {
+ // Branch for the `continuing` block, this must be
+ // the last thing in the loop body
+
+ // Expect a opening brace to start the continuing block
+ lexer.expect(Token::Paren('{'))?;
+ loop {
+ if lexer.skip(Token::Word("break")) {
+ // Branch for the `break if` statement, this statement
+ // has the form `break if <expr>;` and must be the last
+ // statement in a continuing block
+
+ // The break must be followed by an `if` to form
+ // the break if
+ lexer.expect(Token::Word("if"))?;
+
+ let condition = self.general_expression(lexer, ctx)?;
+ // Set the condition of the break if to the newly parsed
+ // expression
+ break_if = Some(condition);
+
+ // Expect a semicolon to close the statement
+ lexer.expect(Token::Separator(';'))?;
+ // Expect a closing brace to close the continuing block,
+ // since the break if must be the last statement
+ lexer.expect(Token::Paren('}'))?;
+ // Stop parsing the continuing block
+ break;
+ } else if lexer.skip(Token::Paren('}')) {
+ // If we encounter a closing brace it means we have reached
+ // the end of the continuing block and should stop processing
+ break;
+ } else {
+ // Otherwise try to parse a statement
+ self.statement(lexer, ctx, &mut continuing)?;
+ }
+ }
+ // Since the continuing block must be the last part of the loop body,
+ // we expect to see a closing brace to end the loop body
+ lexer.expect(Token::Paren('}'))?;
+ break;
+ }
+ if lexer.skip(Token::Paren('}')) {
+ // If we encounter a closing brace it means we have reached
+ // the end of the loop body and should stop processing
+ break;
+ }
+ // Otherwise try to parse a statement
+ self.statement(lexer, ctx, &mut body)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ Ok(ast::StatementKind::Loop {
+ body,
+ continuing,
+ break_if,
+ })
+ }
+
+ /// compound_statement
+ fn block<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<(ast::Block<'a>, Span), Error<'a>> {
+ self.push_rule_span(Rule::Block, lexer);
+
+ ctx.local_table.push_scope();
+
+ lexer.expect(Token::Paren('{'))?;
+ let mut block = ast::Block::default();
+ while !lexer.skip(Token::Paren('}')) {
+ self.statement(lexer, ctx, &mut block)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ let span = self.pop_rule_span(lexer);
+ Ok((block, span))
+ }
+
+ fn varying_binding<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::Binding<'a>>, Error<'a>> {
+ let mut bind_parser = BindingParser::default();
+ self.push_rule_span(Rule::Attribute, lexer);
+
+ while lexer.skip(Token::Attribute) {
+ let (word, span) = lexer.next_ident_with_span()?;
+ bind_parser.parse(self, lexer, word, span, ctx)?;
+ }
+
+ let span = self.pop_rule_span(lexer);
+ bind_parser.finish(span)
+ }
+
+ fn function_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ out: &mut ast::TranslationUnit<'a>,
+ dependencies: &mut FastIndexSet<ast::Dependency<'a>>,
+ ) -> Result<ast::Function<'a>, Error<'a>> {
+ self.push_rule_span(Rule::FunctionDecl, lexer);
+ // read function name
+ let fun_name = lexer.next_ident()?;
+
+ let mut locals = Arena::new();
+
+ let mut ctx = ExpressionContext {
+ expressions: &mut out.expressions,
+ local_table: &mut SymbolTable::default(),
+ locals: &mut locals,
+ types: &mut out.types,
+ unresolved: dependencies,
+ };
+
+ // start a scope that contains arguments as well as the function body
+ ctx.local_table.push_scope();
+
+ // read parameter list
+ let mut arguments = Vec::new();
+ lexer.expect(Token::Paren('('))?;
+ let mut ready = true;
+ while !lexer.skip(Token::Paren(')')) {
+ if !ready {
+ return Err(Error::Unexpected(
+ lexer.next().1,
+ ExpectedToken::Token(Token::Separator(',')),
+ ));
+ }
+ let binding = self.varying_binding(lexer, &mut ctx)?;
+
+ let param_name = lexer.next_ident()?;
+
+ lexer.expect(Token::Separator(':'))?;
+ let param_type = self.type_decl(lexer, &mut ctx)?;
+
+ let handle = ctx.declare_local(param_name)?;
+ arguments.push(ast::FunctionArgument {
+ name: param_name,
+ ty: param_type,
+ binding,
+ handle,
+ });
+ ready = lexer.skip(Token::Separator(','));
+ }
+ // read return type
+ let result = if lexer.skip(Token::Arrow) && !lexer.skip(Token::Word("void")) {
+ let binding = self.varying_binding(lexer, &mut ctx)?;
+ let ty = self.type_decl(lexer, &mut ctx)?;
+ Some(ast::FunctionResult { ty, binding })
+ } else {
+ None
+ };
+
+ // do not use `self.block` here, since we must not push a new scope
+ lexer.expect(Token::Paren('{'))?;
+ let mut body = ast::Block::default();
+ while !lexer.skip(Token::Paren('}')) {
+ self.statement(lexer, &mut ctx, &mut body)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ let fun = ast::Function {
+ entry_point: None,
+ name: fun_name,
+ arguments,
+ result,
+ body,
+ locals,
+ };
+
+ // done
+ self.pop_rule_span(lexer);
+
+ Ok(fun)
+ }
+
+ fn global_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ out: &mut ast::TranslationUnit<'a>,
+ ) -> Result<(), Error<'a>> {
+ // read attributes
+ let mut binding = None;
+ let mut stage = ParsedAttribute::default();
+ let mut compute_span = Span::new(0, 0);
+ let mut workgroup_size = ParsedAttribute::default();
+ let mut early_depth_test = ParsedAttribute::default();
+ let (mut bind_index, mut bind_group) =
+ (ParsedAttribute::default(), ParsedAttribute::default());
+
+ let mut dependencies = FastIndexSet::default();
+ let mut ctx = ExpressionContext {
+ expressions: &mut out.expressions,
+ local_table: &mut SymbolTable::default(),
+ locals: &mut Arena::new(),
+ types: &mut out.types,
+ unresolved: &mut dependencies,
+ };
+
+ self.push_rule_span(Rule::Attribute, lexer);
+ while lexer.skip(Token::Attribute) {
+ match lexer.next_ident_with_span()? {
+ ("binding", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ bind_index.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+ ("group", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+ ("vertex", name_span) => {
+ stage.set(crate::ShaderStage::Vertex, name_span)?;
+ }
+ ("fragment", name_span) => {
+ stage.set(crate::ShaderStage::Fragment, name_span)?;
+ }
+ ("compute", name_span) => {
+ stage.set(crate::ShaderStage::Compute, name_span)?;
+ compute_span = name_span;
+ }
+ ("workgroup_size", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ let mut new_workgroup_size = [None; 3];
+ for (i, size) in new_workgroup_size.iter_mut().enumerate() {
+ *size = Some(self.general_expression(lexer, &mut ctx)?);
+ match lexer.next() {
+ (Token::Paren(')'), _) => break,
+ (Token::Separator(','), _) if i != 2 => (),
+ other => {
+ return Err(Error::Unexpected(
+ other.1,
+ ExpectedToken::WorkgroupSizeSeparator,
+ ))
+ }
+ }
+ }
+ workgroup_size.set(new_workgroup_size, name_span)?;
+ }
+ ("early_depth_test", name_span) => {
+ let conservative = if lexer.skip(Token::Paren('(')) {
+ let (ident, ident_span) = lexer.next_ident_with_span()?;
+ let value = conv::map_conservative_depth(ident, ident_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(value)
+ } else {
+ None
+ };
+ early_depth_test.set(crate::EarlyDepthTest { conservative }, name_span)?;
+ }
+ (_, word_span) => return Err(Error::UnknownAttribute(word_span)),
+ }
+ }
+
+ let attrib_span = self.pop_rule_span(lexer);
+ match (bind_group.value, bind_index.value) {
+ (Some(group), Some(index)) => {
+ binding = Some(ast::ResourceBinding {
+ group,
+ binding: index,
+ });
+ }
+ (Some(_), None) => return Err(Error::MissingAttribute("binding", attrib_span)),
+ (None, Some(_)) => return Err(Error::MissingAttribute("group", attrib_span)),
+ (None, None) => {}
+ }
+
+ // read item
+ let start = lexer.start_byte_offset();
+ let kind = match lexer.next() {
+ (Token::Separator(';'), _) => None,
+ (Token::Word("struct"), _) => {
+ let name = lexer.next_ident()?;
+
+ let members = self.struct_body(lexer, &mut ctx)?;
+ Some(ast::GlobalDeclKind::Struct(ast::Struct { name, members }))
+ }
+ (Token::Word("alias"), _) => {
+ let name = lexer.next_ident()?;
+
+ lexer.expect(Token::Operation('='))?;
+ let ty = self.type_decl(lexer, &mut ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+ Some(ast::GlobalDeclKind::Type(ast::TypeAlias { name, ty }))
+ }
+ (Token::Word("const"), _) => {
+ let name = lexer.next_ident()?;
+
+ let ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, &mut ctx)?;
+ Some(ty)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Operation('='))?;
+ let init = self.general_expression(lexer, &mut ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+
+ Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init }))
+ }
+ (Token::Word("var"), _) => {
+ let mut var = self.variable_decl(lexer, &mut ctx)?;
+ var.binding = binding.take();
+ Some(ast::GlobalDeclKind::Var(var))
+ }
+ (Token::Word("fn"), _) => {
+ let function = self.function_decl(lexer, out, &mut dependencies)?;
+ Some(ast::GlobalDeclKind::Fn(ast::Function {
+ entry_point: if let Some(stage) = stage.value {
+ if stage == ShaderStage::Compute && workgroup_size.value.is_none() {
+ return Err(Error::MissingWorkgroupSize(compute_span));
+ }
+ Some(ast::EntryPoint {
+ stage,
+ early_depth_test: early_depth_test.value,
+ workgroup_size: workgroup_size.value,
+ })
+ } else {
+ None
+ },
+ ..function
+ }))
+ }
+ (Token::End, _) => return Ok(()),
+ other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)),
+ };
+
+ if let Some(kind) = kind {
+ out.decls.append(
+ ast::GlobalDecl { kind, dependencies },
+ lexer.span_from(start),
+ );
+ }
+
+ if !self.rules.is_empty() {
+ log::error!("Reached the end of global decl, but rule stack is not empty");
+ log::error!("Rules: {:?}", self.rules);
+ return Err(Error::Internal("rule stack is not empty"));
+ };
+
+ match binding {
+ None => Ok(()),
+ Some(_) => Err(Error::Internal("we had the attribute but no var?")),
+ }
+ }
+
+ pub fn parse<'a>(&mut self, source: &'a str) -> Result<ast::TranslationUnit<'a>, Error<'a>> {
+ self.reset();
+
+ let mut lexer = Lexer::new(source);
+ let mut tu = ast::TranslationUnit::default();
+ loop {
+ match self.global_decl(&mut lexer, &mut tu) {
+ Err(error) => return Err(error),
+ Ok(()) => {
+ if lexer.peek().0 == Token::End {
+ break;
+ }
+ }
+ }
+ }
+
+ Ok(tu)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/number.rs b/third_party/rust/naga/src/front/wgsl/parse/number.rs
new file mode 100644
index 0000000000..7b09ac59bb
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/number.rs
@@ -0,0 +1,420 @@
+use crate::front::wgsl::error::NumberError;
+use crate::front::wgsl::parse::lexer::Token;
+
+/// When using this type assume no Abstract Int/Float for now
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Number {
+ /// Abstract Int (-2^63 ≤ i < 2^63)
+ AbstractInt(i64),
+ /// Abstract Float (IEEE-754 binary64)
+ AbstractFloat(f64),
+ /// Concrete i32
+ I32(i32),
+ /// Concrete u32
+ U32(u32),
+ /// Concrete f32
+ F32(f32),
+ /// Concrete f64
+ F64(f64),
+}
+
+pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
+ let (result, rest) = parse(input);
+ (Token::Number(result), rest)
+}
+
+enum Kind {
+ Int(IntKind),
+ Float(FloatKind),
+}
+
+enum IntKind {
+ I32,
+ U32,
+}
+
+#[derive(Debug)]
+enum FloatKind {
+ F16,
+ F32,
+ F64,
+}
+
+// The following regexes (from the WGSL spec) will be matched:
+
+// int_literal:
+// | / 0 [iu]? /
+// | / [1-9][0-9]* [iu]? /
+// | / 0[xX][0-9a-fA-F]+ [iu]? /
+
+// decimal_float_literal:
+// | / 0 [fh] /
+// | / [1-9][0-9]* [fh] /
+// | / [0-9]* \.[0-9]+ ([eE][+-]?[0-9]+)? [fh]? /
+// | / [0-9]+ \.[0-9]* ([eE][+-]?[0-9]+)? [fh]? /
+// | / [0-9]+ [eE][+-]?[0-9]+ [fh]? /
+
+// hex_float_literal:
+// | / 0[xX][0-9a-fA-F]* \.[0-9a-fA-F]+ ([pP][+-]?[0-9]+ [fh]?)? /
+// | / 0[xX][0-9a-fA-F]+ \.[0-9a-fA-F]* ([pP][+-]?[0-9]+ [fh]?)? /
+// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? /
+
+// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing
+// (?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))
+
+// Leading signs are handled as unary operators.
+
+fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
+ /// returns `true` and consumes `X` bytes from the given byte buffer
+ /// if the given `X` nr of patterns are found at the start of the buffer
+ macro_rules! consume {
+ ($bytes:ident, $($pattern:pat),*) => {
+ match $bytes {
+ &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
+ _ => false,
+ }
+ };
+ }
+
+ /// consumes one byte from the given byte buffer
+ /// if one of the given patterns are found at the start of the buffer
+ /// returning the corresponding expr for the matched pattern
+ macro_rules! consume_map {
+ ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => {
+ match $bytes {
+ $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
+ _ => None,
+ }
+ };
+ }
+
+ /// consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer
+ /// returning the number of consumed bytes
+ macro_rules! consume_dec_digits {
+ ($bytes:ident) => {{
+ let start_len = $bytes.len();
+ while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
+ $bytes = rest;
+ }
+ start_len - $bytes.len()
+ }};
+ }
+
+ /// consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given byte buffer
+ /// returning the number of consumed bytes
+ macro_rules! consume_hex_digits {
+ ($bytes:ident) => {{
+ let start_len = $bytes.len();
+ while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
+ $bytes = rest;
+ }
+ start_len - $bytes.len()
+ }};
+ }
+
+ macro_rules! consume_float_suffix {
+ ($bytes:ident) => {
+ consume_map!($bytes, [
+ b'h' => FloatKind::F16,
+ b'f' => FloatKind::F32,
+ b'l', b'f' => FloatKind::F64,
+ ])
+ };
+ }
+
+ /// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str`
+ macro_rules! rest_to_str {
+ ($bytes:ident) => {
+ &input[input.len() - $bytes.len()..]
+ };
+ }
+
+ struct ExtractSubStr<'a>(&'a str);
+
+ impl<'a> ExtractSubStr<'a> {
+ /// given an `input` and a `start` (tail of the `input`)
+ /// creates a new [`ExtractSubStr`](`Self`)
+ fn start(input: &'a str, start: &'a [u8]) -> Self {
+ let start = input.len() - start.len();
+ Self(&input[start..])
+ }
+ /// given an `end` (tail of the initial `input`)
+ /// returns a substring of `input`
+ fn end(&self, end: &'a [u8]) -> &'a str {
+ let end = self.0.len() - end.len();
+ &self.0[..end]
+ }
+ }
+
+ let mut bytes = input.as_bytes();
+
+ let general_extract = ExtractSubStr::start(input, bytes);
+
+ if consume!(bytes, b'0', b'x' | b'X') {
+ let digits_extract = ExtractSubStr::start(input, bytes);
+
+ let consumed = consume_hex_digits!(bytes);
+
+ if consume!(bytes, b'.') {
+ let consumed_after_period = consume_hex_digits!(bytes);
+
+ if consumed + consumed_after_period == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let significand = general_extract.end(bytes);
+
+ if consume!(bytes, b'p' | b'P') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_float_suffix!(bytes);
+
+ (parse_hex_float(number, kind), rest_to_str!(bytes))
+ } else {
+ (
+ parse_hex_float_missing_exponent(significand, None),
+ rest_to_str!(bytes),
+ )
+ }
+ } else {
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let significand = general_extract.end(bytes);
+ let digits = digits_extract.end(bytes);
+
+ let exp_extract = ExtractSubStr::start(input, bytes);
+
+ if consume!(bytes, b'p' | b'P') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let exponent = exp_extract.end(bytes);
+
+ let kind = consume_float_suffix!(bytes);
+
+ (
+ parse_hex_float_missing_period(significand, exponent, kind),
+ rest_to_str!(bytes),
+ )
+ } else {
+ let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]);
+
+ (parse_hex_int(digits, kind), rest_to_str!(bytes))
+ }
+ }
+ } else {
+ let is_first_zero = bytes.first() == Some(&b'0');
+
+ let consumed = consume_dec_digits!(bytes);
+
+ if consume!(bytes, b'.') {
+ let consumed_after_period = consume_dec_digits!(bytes);
+
+ if consumed + consumed_after_period == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ if consume!(bytes, b'e' | b'E') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_float_suffix!(bytes);
+
+ (parse_dec_float(number, kind), rest_to_str!(bytes))
+ } else {
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ if consume!(bytes, b'e' | b'E') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_float_suffix!(bytes);
+
+ (parse_dec_float(number, kind), rest_to_str!(bytes))
+ } else {
+ // make sure the multi-digit numbers don't start with zero
+ if consumed > 1 && is_first_zero {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let digits = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [
+ b'i' => Kind::Int(IntKind::I32),
+ b'u' => Kind::Int(IntKind::U32),
+ b'h' => Kind::Float(FloatKind::F16),
+ b'f' => Kind::Float(FloatKind::F32),
+ b'l', b'f' => Kind::Float(FloatKind::F64),
+ ]);
+
+ (parse_dec(digits, kind), rest_to_str!(bytes))
+ }
+ }
+ }
+}
+
+fn parse_hex_float_missing_exponent(
+ // format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
+ significand: &str,
+ kind: Option<FloatKind>,
+) -> Result<Number, NumberError> {
+ let hexf_input = format!("{}{}", significand, "p0");
+ parse_hex_float(&hexf_input, kind)
+}
+
+fn parse_hex_float_missing_period(
+ // format: 0[xX] [0-9a-fA-F]+
+ significand: &str,
+ // format: [pP][+-]?[0-9]+
+ exponent: &str,
+ kind: Option<FloatKind>,
+) -> Result<Number, NumberError> {
+ let hexf_input = format!("{significand}.{exponent}");
+ parse_hex_float(&hexf_input, kind)
+}
+
+fn parse_hex_int(
+ // format: [0-9a-fA-F]+
+ digits: &str,
+ kind: Option<IntKind>,
+) -> Result<Number, NumberError> {
+ parse_int(digits, kind, 16)
+}
+
+fn parse_dec(
+ // format: ( [0-9] | [1-9][0-9]+ )
+ digits: &str,
+ kind: Option<Kind>,
+) -> Result<Number, NumberError> {
+ match kind {
+ None => parse_int(digits, None, 10),
+ Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10),
+ Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)),
+ }
+}
+
+// Float parsing notes
+
+// The following chapters of IEEE 754-2019 are relevant:
+//
+// 7.4 Overflow (largest finite number is exceeded by what would have been
+// the rounded floating-point result were the exponent range unbounded)
+//
+// 7.5 Underflow (tiny non-zero result is detected;
+// for decimal formats tininess is detected before rounding when a non-zero result
+// computed as though both the exponent range and the precision were unbounded
+// would lie strictly between 2^−126)
+//
+// 7.6 Inexact (rounded result differs from what would have been computed
+// were both exponent range and precision unbounded)
+
+// The WGSL spec requires us to error:
+// on overflow for decimal floating point literals
+// on overflow and inexact for hexadecimal floating point literals
+// (underflow is not mentioned)
+
+// hexf_parse errors on overflow, underflow, inexact
+// rust std lib float from str handles overflow, underflow, inexact transparently (rounds and will not error)
+
+// Therefore we only check for overflow manually for decimal floating point literals
+
+// input format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
+fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
+ match kind {
+ None => match hexf_parse::parse_hexf64(input, false) {
+ Ok(num) => Ok(Number::AbstractFloat(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
+ Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
+ Ok(num) => Ok(Number::F32(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) {
+ Ok(num) => Ok(Number::F64(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ }
+}
+
+// input format: ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
+// | [0-9]+ [eE][+-]?[0-9]+
+fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
+ match kind {
+ None => {
+ let num = input.parse::<f64>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::AbstractFloat(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F32) => {
+ let num = input.parse::<f32>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::F32(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F64) => {
+ let num = input.parse::<f64>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::F64(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
+ }
+}
+
+fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> {
+ fn map_err(e: core::num::ParseIntError) -> NumberError {
+ match *e.kind() {
+ core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
+ NumberError::NotRepresentable
+ }
+ _ => unreachable!(),
+ }
+ }
+ match kind {
+ None => match i64::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::AbstractInt(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ Some(IntKind::I32) => match i32::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::I32(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::U32(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/tests.rs b/third_party/rust/naga/src/front/wgsl/tests.rs
new file mode 100644
index 0000000000..eb2f8a2eb3
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/tests.rs
@@ -0,0 +1,637 @@
+use super::parse_str;
+
+#[test]
+fn parse_comment() {
+ parse_str(
+ "//
+ ////
+ ///////////////////////////////////////////////////////// asda
+ //////////////////// dad ////////// /
+ /////////////////////////////////////////////////////////////////////////////////////////////////////
+ //
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_types() {
+ parse_str("const a : i32 = 2;").unwrap();
+ assert!(parse_str("const a : x32 = 2;").is_err());
+ parse_str("var t: texture_2d<f32>;").unwrap();
+ parse_str("var t: texture_cube_array<i32>;").unwrap();
+ parse_str("var t: texture_multisampled_2d<u32>;").unwrap();
+ parse_str("var t: texture_storage_1d<rgba8uint,write>;").unwrap();
+ parse_str("var t: texture_storage_3d<r32float,read>;").unwrap();
+}
+
+#[test]
+fn parse_type_inference() {
+ parse_str(
+ "
+ fn foo() {
+ let a = 2u;
+ let b: u32 = a;
+ var x = 3.;
+ var y = vec2<f32>(1, 2);
+ }",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn foo() { let c : i32 = 2.0; }",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_type_cast() {
+ parse_str(
+ "
+ const a : i32 = 2;
+ fn main() {
+ var x: f32 = f32(a);
+ x = f32(i32(a + 1) / 2);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(1.0, 2.0);
+ let y: vec2<u32> = vec2<u32>(x);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0.0);
+ }
+ ",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0i, 0i);
+ }
+ ",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_struct() {
+ parse_str(
+ "
+ struct Foo { x: i32 }
+ struct Bar {
+ @size(16) x: vec2<i32>,
+ @align(16) y: f32,
+ @size(32) @align(128) z: vec3<f32>,
+ };
+ struct Empty {}
+ var<storage,read_write> s: Foo;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_standard_fun() {
+ parse_str(
+ "
+ fn main() {
+ var x: i32 = min(max(1, 2), 3);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_statement() {
+ parse_str(
+ "
+ fn main() {
+ ;
+ {}
+ {;}
+ }
+ ",
+ )
+ .unwrap();
+
+ parse_str(
+ "
+ fn foo() {}
+ fn bar() { foo(); }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_if() {
+ parse_str(
+ "
+ fn main() {
+ if true {
+ discard;
+ } else {}
+ if 0 != 1 {}
+ if false {
+ return;
+ } else if true {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_if() {
+ parse_str(
+ "
+ fn main() {
+ if (true) {
+ discard;
+ } else {}
+ if (0 != 1) {}
+ if (false) {
+ return;
+ } else if (true) {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_loop() {
+ parse_str(
+ "
+ fn main() {
+ var i: i32 = 0;
+ loop {
+ if i == 1 { break; }
+ continuing { i = 1; }
+ }
+ loop {
+ if i == 0 { continue; }
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var found: bool = false;
+ var i: i32 = 0;
+ while !found {
+ if i == 10 {
+ found = true;
+ }
+
+ i = i + 1;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ while true {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var a: i32 = 0;
+ for(var i: i32 = 0; i < 4; i = i + 1) {
+ a = a + 2;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ for(;;) {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: { pos = 1.0; }
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_optional_colon_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1 { pos = 0.0; }
+ case 2 { pos = 1.0; }
+ default { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_default_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: {}
+ case default, 3: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch pos > 1.0 {
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_load() {
+ parse_str(
+ "
+ var t: texture_3d<u32>;
+ fn foo() {
+ let r: vec4<u32> = textureLoad(t, vec3<u32>(0u, 1u, 2u), 1);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<i32>;
+ fn foo() {
+ let r: vec4<i32> = textureLoad(t, vec2<i32>(10, 20), 2, 3);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_storage_1d_array<r32float,read>;
+ fn foo() {
+ let r: vec4<f32> = textureLoad(t, 10, 2);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_store() {
+ parse_str(
+ "
+ var t: texture_storage_2d<rgba8unorm,write>;
+ fn foo() {
+ textureStore(t, vec2<i32>(10, 20), vec4<f32>(0.0, 1.0, 2.0, 3.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_query() {
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<f32>;
+ fn foo() {
+ var dim: vec2<u32> = textureDimensions(t);
+ dim = textureDimensions(t, 0);
+ let layers: u32 = textureNumLayers(t);
+ let samples: u32 = textureNumSamples(t);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_postfix() {
+ parse_str(
+ "fn foo() {
+ let x: f32 = vec4<f32>(1.0, 2.0, 3.0, 4.0).xyz.rgbr.aaaa.wz.g;
+ let y: f32 = fract(vec2<f32>(0.5, x)).x;
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_expressions() {
+ parse_str("fn foo() {
+ let x: f32 = select(0.0, 1.0, true);
+ let y: vec2<f32> = select(vec2<f32>(1.0, 1.0), vec2<f32>(x, x), vec2<bool>(x < 0.5, x > 0.5));
+ let z: bool = !(0.0 == 1.0);
+ }").unwrap();
+}
+
+#[test]
+fn binary_expression_mixed_scalar_and_vector_operands() {
+ for (operand, expect_splat) in [
+ ('<', false),
+ ('>', false),
+ ('&', false),
+ ('|', false),
+ ('+', true),
+ ('-', true),
+ ('*', false),
+ ('/', true),
+ ('%', true),
+ ] {
+ let module = parse_str(&format!(
+ "
+ @fragment
+ fn main(@location(0) some_vec: vec3<f32>) -> @location(0) vec4<f32> {{
+ if (all(1.0 {operand} some_vec)) {{
+ return vec4(0.0);
+ }}
+ return vec4(1.0);
+ }}
+ "
+ ))
+ .unwrap();
+
+ let expressions = &&module.entry_points[0].function.expressions;
+
+ let found_expressions = expressions
+ .iter()
+ .filter(|&(_, e)| {
+ if let crate::Expression::Binary { left, .. } = *e {
+ matches!(
+ (expect_splat, &expressions[left]),
+ (false, &crate::Expression::Literal(crate::Literal::F32(..)))
+ | (true, &crate::Expression::Splat { .. })
+ )
+ } else {
+ false
+ }
+ })
+ .count();
+
+ assert_eq!(
+ found_expressions,
+ 1,
+ "expected `{operand}` expression {} splat",
+ if expect_splat { "with" } else { "without" }
+ );
+ }
+
+ let module = parse_str(
+ "@fragment
+ fn main(mat: mat3x3<f32>) {
+ let vec = vec3<f32>(1.0, 1.0, 1.0);
+ let result = mat / vec;
+ }",
+ )
+ .unwrap();
+ let expressions = &&module.entry_points[0].function.expressions;
+ let found_splat = expressions.iter().any(|(_, e)| {
+ if let crate::Expression::Binary { left, .. } = *e {
+ matches!(&expressions[left], &crate::Expression::Splat { .. })
+ } else {
+ false
+ }
+ });
+ assert!(!found_splat, "'mat / vec' should not be splatted");
+}
+
+#[test]
+fn parse_pointers() {
+ parse_str(
+ "fn foo(a: ptr<private, f32>) -> f32 { return *a; }
+ fn bar() {
+ var x: f32 = 1.0;
+ let px = &x;
+ let py = foo(px);
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_struct_instantiation() {
+ parse_str(
+ "
+ struct Foo {
+ a: f32,
+ b: vec3<f32>,
+ }
+
+ @fragment
+ fn fs_main() {
+ var foo: Foo = Foo(0.0, vec3<f32>(0.0, 1.0, 42.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_array_length() {
+ parse_str(
+ "
+ struct Foo {
+ data: array<u32>
+ } // this is used as both input and output for convenience
+
+ @group(0) @binding(0)
+ var<storage> foo: Foo;
+
+ @group(0) @binding(1)
+ var<storage> bar: array<u32>;
+
+ fn baz() {
+ var x: u32 = arrayLength(foo.data);
+ var y: u32 = arrayLength(bar);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_storage_buffers() {
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read_write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_alias() {
+ parse_str(
+ "
+ alias Vec4 = vec4<f32>;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_load_store_expecting_four_args() {
+ for (func, texture) in [
+ (
+ "textureStore",
+ "texture_storage_2d_array<rg11b10float, write>",
+ ),
+ ("textureLoad", "texture_2d_array<i32>"),
+ ] {
+ let error = parse_str(&format!(
+ "
+ @group(0) @binding(0) var tex_los_res: {texture};
+ @compute
+ @workgroup_size(1)
+ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
+ var color = vec4(1, 1, 1, 1);
+ {func}(tex_los_res, id, color);
+ }}
+ "
+ ))
+ .unwrap_err();
+ assert_eq!(
+ error.message(),
+ "wrong number of arguments: expected 4, found 3"
+ );
+ }
+}
+
+#[test]
+fn parse_repeated_attributes() {
+ use crate::{
+ front::wgsl::{error::Error, Frontend},
+ Span,
+ };
+
+ let template_vs = "@vertex fn vs() -> __REPLACE__ vec4<f32> { return vec4<f32>(0.0); }";
+ let template_struct = "struct A { __REPLACE__ data: vec3<f32> }";
+ let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array<i32>;";
+ let template_stage = "__REPLACE__ fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
+ for (attribute, template) in [
+ ("align(16)", template_struct),
+ ("binding(0)", template_resource),
+ ("builtin(position)", template_vs),
+ ("compute", template_stage),
+ ("fragment", template_stage),
+ ("group(0)", template_resource),
+ ("interpolate(flat)", template_vs),
+ ("invariant", template_vs),
+ ("location(0)", template_vs),
+ ("size(16)", template_struct),
+ ("vertex", template_stage),
+ ("early_depth_test(less_equal)", template_resource),
+ ("workgroup_size(1)", template_stage),
+ ] {
+ let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}"));
+ let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32;
+ let span_start = shader.rfind(attribute).unwrap() as u32;
+ let span_end = span_start + name_length;
+ let expected_span = Span::new(span_start, span_end);
+
+ let result = Frontend::new().inner(&shader);
+ assert!(matches!(
+ result.unwrap_err(),
+ Error::RepeatedAttribute(span) if span == expected_span
+ ));
+ }
+}
+
+#[test]
+fn parse_missing_workgroup_size() {
+ use crate::{
+ front::wgsl::{error::Error, Frontend},
+ Span,
+ };
+
+ let shader = "@compute fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
+ let result = Frontend::new().inner(shader);
+ assert!(matches!(
+ result.unwrap_err(),
+ Error::MissingWorkgroupSize(span) if span == Span::new(1, 8)
+ ));
+}
diff --git a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
new file mode 100644
index 0000000000..c8331ace09
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
@@ -0,0 +1,283 @@
+//! Producing the WGSL forms of types, for use in error messages.
+
+use crate::proc::GlobalCtx;
+use crate::Handle;
+
+impl crate::proc::TypeResolution {
+ pub fn to_wgsl(&self, gctx: &GlobalCtx) -> String {
+ match *self {
+ crate::proc::TypeResolution::Handle(handle) => handle.to_wgsl(gctx),
+ crate::proc::TypeResolution::Value(ref inner) => inner.to_wgsl(gctx),
+ }
+ }
+}
+
+impl Handle<crate::Type> {
+ /// Formats the type as it is written in wgsl.
+ ///
+ /// For example `vec3<f32>`.
+ pub fn to_wgsl(self, gctx: &GlobalCtx) -> String {
+ let ty = &gctx.types[self];
+ match ty.name {
+ Some(ref name) => name.clone(),
+ None => ty.inner.to_wgsl(gctx),
+ }
+ }
+}
+
+impl crate::TypeInner {
+ /// Formats the type as it is written in wgsl.
+ ///
+ /// For example `vec3<f32>`.
+ ///
+ /// Note: `TypeInner::Struct` doesn't include the name of the
+ /// struct type. Therefore this method will simply return "struct"
+ /// for them.
+ pub fn to_wgsl(&self, gctx: &GlobalCtx) -> String {
+ use crate::TypeInner as Ti;
+
+ match *self {
+ Ti::Scalar(scalar) => scalar.to_wgsl(),
+ Ti::Vector { size, scalar } => {
+ format!("vec{}<{}>", size as u32, scalar.to_wgsl())
+ }
+ Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ format!(
+ "mat{}x{}<{}>",
+ columns as u32,
+ rows as u32,
+ scalar.to_wgsl(),
+ )
+ }
+ Ti::Atomic(scalar) => {
+ format!("atomic<{}>", scalar.to_wgsl())
+ }
+ Ti::Pointer { base, .. } => {
+ let name = base.to_wgsl(gctx);
+ format!("ptr<{name}>")
+ }
+ Ti::ValuePointer { scalar, .. } => {
+ format!("ptr<{}>", scalar.to_wgsl())
+ }
+ Ti::Array { base, size, .. } => {
+ let base = base.to_wgsl(gctx);
+ match size {
+ crate::ArraySize::Constant(size) => format!("array<{base}, {size}>"),
+ crate::ArraySize::Dynamic => format!("array<{base}>"),
+ }
+ }
+ Ti::Struct { .. } => {
+ // TODO: Actually output the struct?
+ "struct".to_string()
+ }
+ Ti::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_suffix = match dim {
+ crate::ImageDimension::D1 => "_1d",
+ crate::ImageDimension::D2 => "_2d",
+ crate::ImageDimension::D3 => "_3d",
+ crate::ImageDimension::Cube => "_cube",
+ };
+ let array_suffix = if arrayed { "_array" } else { "" };
+
+ let class_suffix = match class {
+ crate::ImageClass::Sampled { multi: true, .. } => "_multisampled",
+ crate::ImageClass::Depth { multi: false } => "_depth",
+ crate::ImageClass::Depth { multi: true } => "_depth_multisampled",
+ crate::ImageClass::Sampled { multi: false, .. }
+ | crate::ImageClass::Storage { .. } => "",
+ };
+
+ let type_in_brackets = match class {
+ crate::ImageClass::Sampled { kind, .. } => {
+ // Note: The only valid widths are 4 bytes wide.
+ // The lexer has already verified this, so we can safely assume it here.
+ // https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ let element_type = crate::Scalar { kind, width: 4 }.to_wgsl();
+ format!("<{element_type}>")
+ }
+ crate::ImageClass::Depth { multi: _ } => String::new(),
+ crate::ImageClass::Storage { format, access } => {
+ if access.contains(crate::StorageAccess::STORE) {
+ format!("<{},write>", format.to_wgsl())
+ } else {
+ format!("<{}>", format.to_wgsl())
+ }
+ }
+ };
+
+ format!("texture{class_suffix}{dim_suffix}{array_suffix}{type_in_brackets}")
+ }
+ Ti::Sampler { .. } => "sampler".to_string(),
+ Ti::AccelerationStructure => "acceleration_structure".to_string(),
+ Ti::RayQuery => "ray_query".to_string(),
+ Ti::BindingArray { base, size, .. } => {
+ let member_type = &gctx.types[base];
+ let base = member_type.name.as_deref().unwrap_or("unknown");
+ match size {
+ crate::ArraySize::Constant(size) => format!("binding_array<{base}, {size}>"),
+ crate::ArraySize::Dynamic => format!("binding_array<{base}>"),
+ }
+ }
+ }
+ }
+}
+
+impl crate::Scalar {
+ /// Format a scalar kind+width as a type is written in wgsl.
+ ///
+ /// Examples: `f32`, `u64`, `bool`.
+ pub fn to_wgsl(self) -> String {
+ let prefix = match self.kind {
+ crate::ScalarKind::Sint => "i",
+ crate::ScalarKind::Uint => "u",
+ crate::ScalarKind::Float => "f",
+ crate::ScalarKind::Bool => return "bool".to_string(),
+ crate::ScalarKind::AbstractInt => return "{AbstractInt}".to_string(),
+ crate::ScalarKind::AbstractFloat => return "{AbstractFloat}".to_string(),
+ };
+ format!("{}{}", prefix, self.width * 8)
+ }
+}
+
+impl crate::StorageFormat {
+ pub const fn to_wgsl(self) -> &'static str {
+ use crate::StorageFormat as Sf;
+ match self {
+ Sf::R8Unorm => "r8unorm",
+ Sf::R8Snorm => "r8snorm",
+ Sf::R8Uint => "r8uint",
+ Sf::R8Sint => "r8sint",
+ Sf::R16Uint => "r16uint",
+ Sf::R16Sint => "r16sint",
+ Sf::R16Float => "r16float",
+ Sf::Rg8Unorm => "rg8unorm",
+ Sf::Rg8Snorm => "rg8snorm",
+ Sf::Rg8Uint => "rg8uint",
+ Sf::Rg8Sint => "rg8sint",
+ Sf::R32Uint => "r32uint",
+ Sf::R32Sint => "r32sint",
+ Sf::R32Float => "r32float",
+ Sf::Rg16Uint => "rg16uint",
+ Sf::Rg16Sint => "rg16sint",
+ Sf::Rg16Float => "rg16float",
+ Sf::Rgba8Unorm => "rgba8unorm",
+ Sf::Rgba8Snorm => "rgba8snorm",
+ Sf::Rgba8Uint => "rgba8uint",
+ Sf::Rgba8Sint => "rgba8sint",
+ Sf::Bgra8Unorm => "bgra8unorm",
+ Sf::Rgb10a2Uint => "rgb10a2uint",
+ Sf::Rgb10a2Unorm => "rgb10a2unorm",
+ Sf::Rg11b10Float => "rg11b10float",
+ Sf::Rg32Uint => "rg32uint",
+ Sf::Rg32Sint => "rg32sint",
+ Sf::Rg32Float => "rg32float",
+ Sf::Rgba16Uint => "rgba16uint",
+ Sf::Rgba16Sint => "rgba16sint",
+ Sf::Rgba16Float => "rgba16float",
+ Sf::Rgba32Uint => "rgba32uint",
+ Sf::Rgba32Sint => "rgba32sint",
+ Sf::Rgba32Float => "rgba32float",
+ Sf::R16Unorm => "r16unorm",
+ Sf::R16Snorm => "r16snorm",
+ Sf::Rg16Unorm => "rg16unorm",
+ Sf::Rg16Snorm => "rg16snorm",
+ Sf::Rgba16Unorm => "rgba16unorm",
+ Sf::Rgba16Snorm => "rgba16snorm",
+ }
+ }
+}
+
+mod tests {
+ #[test]
+ fn to_wgsl() {
+ use std::num::NonZeroU32;
+
+ let mut types = crate::UniqueArena::new();
+
+ let mytype1 = types.insert(
+ crate::Type {
+ name: Some("MyType1".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![],
+ span: 0,
+ },
+ },
+ Default::default(),
+ );
+ let mytype2 = types.insert(
+ crate::Type {
+ name: Some("MyType2".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![],
+ span: 0,
+ },
+ },
+ Default::default(),
+ );
+
+ let gctx = crate::proc::GlobalCtx {
+ types: &types,
+ constants: &crate::Arena::new(),
+ const_expressions: &crate::Arena::new(),
+ };
+ let array = crate::TypeInner::Array {
+ base: mytype1,
+ stride: 4,
+ size: crate::ArraySize::Constant(unsafe { NonZeroU32::new_unchecked(32) }),
+ };
+ assert_eq!(array.to_wgsl(&gctx), "array<MyType1, 32>");
+
+ let mat = crate::TypeInner::Matrix {
+ rows: crate::VectorSize::Quad,
+ columns: crate::VectorSize::Bi,
+ scalar: crate::Scalar::F64,
+ };
+ assert_eq!(mat.to_wgsl(&gctx), "mat2x4<f64>");
+
+ let ptr = crate::TypeInner::Pointer {
+ base: mytype2,
+ space: crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ },
+ };
+ assert_eq!(ptr.to_wgsl(&gctx), "ptr<MyType2>");
+
+ let img1 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: true,
+ },
+ };
+ assert_eq!(img1.to_wgsl(&gctx), "texture_multisampled_2d<f32>");
+
+ let img2 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ };
+ assert_eq!(img2.to_wgsl(&gctx), "texture_depth_cube_array");
+
+ let img3 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: true },
+ };
+ assert_eq!(img3.to_wgsl(&gctx), "texture_depth_multisampled_2d");
+
+ let array = crate::TypeInner::BindingArray {
+ base: mytype1,
+ size: crate::ArraySize::Constant(unsafe { NonZeroU32::new_unchecked(32) }),
+ };
+ assert_eq!(array.to_wgsl(&gctx), "binding_array<MyType1, 32>");
+ }
+}