summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/front/wgsl/tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/front/wgsl/tests.rs')
-rw-r--r--third_party/rust/naga/src/front/wgsl/tests.rs637
1 files changed, 637 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/front/wgsl/tests.rs b/third_party/rust/naga/src/front/wgsl/tests.rs
new file mode 100644
index 0000000000..eb2f8a2eb3
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/tests.rs
@@ -0,0 +1,637 @@
+use super::parse_str;
+
+#[test]
+fn parse_comment() {
+ parse_str(
+ "//
+ ////
+ ///////////////////////////////////////////////////////// asda
+ //////////////////// dad ////////// /
+ /////////////////////////////////////////////////////////////////////////////////////////////////////
+ //
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_types() {
+ parse_str("const a : i32 = 2;").unwrap();
+ assert!(parse_str("const a : x32 = 2;").is_err());
+ parse_str("var t: texture_2d<f32>;").unwrap();
+ parse_str("var t: texture_cube_array<i32>;").unwrap();
+ parse_str("var t: texture_multisampled_2d<u32>;").unwrap();
+ parse_str("var t: texture_storage_1d<rgba8uint,write>;").unwrap();
+ parse_str("var t: texture_storage_3d<r32float,read>;").unwrap();
+}
+
+#[test]
+fn parse_type_inference() {
+ parse_str(
+ "
+ fn foo() {
+ let a = 2u;
+ let b: u32 = a;
+ var x = 3.;
+ var y = vec2<f32>(1, 2);
+ }",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn foo() { let c : i32 = 2.0; }",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_type_cast() {
+ parse_str(
+ "
+ const a : i32 = 2;
+ fn main() {
+ var x: f32 = f32(a);
+ x = f32(i32(a + 1) / 2);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(1.0, 2.0);
+ let y: vec2<u32> = vec2<u32>(x);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0.0);
+ }
+ ",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0i, 0i);
+ }
+ ",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_struct() {
+ parse_str(
+ "
+ struct Foo { x: i32 }
+ struct Bar {
+ @size(16) x: vec2<i32>,
+ @align(16) y: f32,
+ @size(32) @align(128) z: vec3<f32>,
+ };
+ struct Empty {}
+ var<storage,read_write> s: Foo;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_standard_fun() {
+ parse_str(
+ "
+ fn main() {
+ var x: i32 = min(max(1, 2), 3);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_statement() {
+ parse_str(
+ "
+ fn main() {
+ ;
+ {}
+ {;}
+ }
+ ",
+ )
+ .unwrap();
+
+ parse_str(
+ "
+ fn foo() {}
+ fn bar() { foo(); }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_if() {
+ parse_str(
+ "
+ fn main() {
+ if true {
+ discard;
+ } else {}
+ if 0 != 1 {}
+ if false {
+ return;
+ } else if true {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_if() {
+ parse_str(
+ "
+ fn main() {
+ if (true) {
+ discard;
+ } else {}
+ if (0 != 1) {}
+ if (false) {
+ return;
+ } else if (true) {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_loop() {
+ parse_str(
+ "
+ fn main() {
+ var i: i32 = 0;
+ loop {
+ if i == 1 { break; }
+ continuing { i = 1; }
+ }
+ loop {
+ if i == 0 { continue; }
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var found: bool = false;
+ var i: i32 = 0;
+ while !found {
+ if i == 10 {
+ found = true;
+ }
+
+ i = i + 1;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ while true {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var a: i32 = 0;
+ for(var i: i32 = 0; i < 4; i = i + 1) {
+ a = a + 2;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ for(;;) {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: { pos = 1.0; }
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_optional_colon_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1 { pos = 0.0; }
+ case 2 { pos = 1.0; }
+ default { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_default_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: {}
+ case default, 3: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch pos > 1.0 {
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_load() {
+ parse_str(
+ "
+ var t: texture_3d<u32>;
+ fn foo() {
+ let r: vec4<u32> = textureLoad(t, vec3<u32>(0u, 1u, 2u), 1);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<i32>;
+ fn foo() {
+ let r: vec4<i32> = textureLoad(t, vec2<i32>(10, 20), 2, 3);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_storage_1d_array<r32float,read>;
+ fn foo() {
+ let r: vec4<f32> = textureLoad(t, 10, 2);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_store() {
+ parse_str(
+ "
+ var t: texture_storage_2d<rgba8unorm,write>;
+ fn foo() {
+ textureStore(t, vec2<i32>(10, 20), vec4<f32>(0.0, 1.0, 2.0, 3.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_query() {
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<f32>;
+ fn foo() {
+ var dim: vec2<u32> = textureDimensions(t);
+ dim = textureDimensions(t, 0);
+ let layers: u32 = textureNumLayers(t);
+ let samples: u32 = textureNumSamples(t);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_postfix() {
+ parse_str(
+ "fn foo() {
+ let x: f32 = vec4<f32>(1.0, 2.0, 3.0, 4.0).xyz.rgbr.aaaa.wz.g;
+ let y: f32 = fract(vec2<f32>(0.5, x)).x;
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_expressions() {
+ parse_str("fn foo() {
+ let x: f32 = select(0.0, 1.0, true);
+ let y: vec2<f32> = select(vec2<f32>(1.0, 1.0), vec2<f32>(x, x), vec2<bool>(x < 0.5, x > 0.5));
+ let z: bool = !(0.0 == 1.0);
+ }").unwrap();
+}
+
+#[test]
+fn binary_expression_mixed_scalar_and_vector_operands() {
+ for (operand, expect_splat) in [
+ ('<', false),
+ ('>', false),
+ ('&', false),
+ ('|', false),
+ ('+', true),
+ ('-', true),
+ ('*', false),
+ ('/', true),
+ ('%', true),
+ ] {
+ let module = parse_str(&format!(
+ "
+ @fragment
+ fn main(@location(0) some_vec: vec3<f32>) -> @location(0) vec4<f32> {{
+ if (all(1.0 {operand} some_vec)) {{
+ return vec4(0.0);
+ }}
+ return vec4(1.0);
+ }}
+ "
+ ))
+ .unwrap();
+
+ let expressions = &&module.entry_points[0].function.expressions;
+
+ let found_expressions = expressions
+ .iter()
+ .filter(|&(_, e)| {
+ if let crate::Expression::Binary { left, .. } = *e {
+ matches!(
+ (expect_splat, &expressions[left]),
+ (false, &crate::Expression::Literal(crate::Literal::F32(..)))
+ | (true, &crate::Expression::Splat { .. })
+ )
+ } else {
+ false
+ }
+ })
+ .count();
+
+ assert_eq!(
+ found_expressions,
+ 1,
+ "expected `{operand}` expression {} splat",
+ if expect_splat { "with" } else { "without" }
+ );
+ }
+
+ let module = parse_str(
+ "@fragment
+ fn main(mat: mat3x3<f32>) {
+ let vec = vec3<f32>(1.0, 1.0, 1.0);
+ let result = mat / vec;
+ }",
+ )
+ .unwrap();
+ let expressions = &&module.entry_points[0].function.expressions;
+ let found_splat = expressions.iter().any(|(_, e)| {
+ if let crate::Expression::Binary { left, .. } = *e {
+ matches!(&expressions[left], &crate::Expression::Splat { .. })
+ } else {
+ false
+ }
+ });
+ assert!(!found_splat, "'mat / vec' should not be splatted");
+}
+
+#[test]
+fn parse_pointers() {
+ parse_str(
+ "fn foo(a: ptr<private, f32>) -> f32 { return *a; }
+ fn bar() {
+ var x: f32 = 1.0;
+ let px = &x;
+ let py = foo(px);
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_struct_instantiation() {
+ parse_str(
+ "
+ struct Foo {
+ a: f32,
+ b: vec3<f32>,
+ }
+
+ @fragment
+ fn fs_main() {
+ var foo: Foo = Foo(0.0, vec3<f32>(0.0, 1.0, 42.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_array_length() {
+ parse_str(
+ "
+ struct Foo {
+ data: array<u32>
+ } // this is used as both input and output for convenience
+
+ @group(0) @binding(0)
+ var<storage> foo: Foo;
+
+ @group(0) @binding(1)
+ var<storage> bar: array<u32>;
+
+ fn baz() {
+ var x: u32 = arrayLength(foo.data);
+ var y: u32 = arrayLength(bar);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_storage_buffers() {
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read_write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_alias() {
+ parse_str(
+ "
+ alias Vec4 = vec4<f32>;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_load_store_expecting_four_args() {
+ for (func, texture) in [
+ (
+ "textureStore",
+ "texture_storage_2d_array<rg11b10float, write>",
+ ),
+ ("textureLoad", "texture_2d_array<i32>"),
+ ] {
+ let error = parse_str(&format!(
+ "
+ @group(0) @binding(0) var tex_los_res: {texture};
+ @compute
+ @workgroup_size(1)
+ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
+ var color = vec4(1, 1, 1, 1);
+ {func}(tex_los_res, id, color);
+ }}
+ "
+ ))
+ .unwrap_err();
+ assert_eq!(
+ error.message(),
+ "wrong number of arguments: expected 4, found 3"
+ );
+ }
+}
+
+#[test]
+fn parse_repeated_attributes() {
+ use crate::{
+ front::wgsl::{error::Error, Frontend},
+ Span,
+ };
+
+ let template_vs = "@vertex fn vs() -> __REPLACE__ vec4<f32> { return vec4<f32>(0.0); }";
+ let template_struct = "struct A { __REPLACE__ data: vec3<f32> }";
+ let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array<i32>;";
+ let template_stage = "__REPLACE__ fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
+ for (attribute, template) in [
+ ("align(16)", template_struct),
+ ("binding(0)", template_resource),
+ ("builtin(position)", template_vs),
+ ("compute", template_stage),
+ ("fragment", template_stage),
+ ("group(0)", template_resource),
+ ("interpolate(flat)", template_vs),
+ ("invariant", template_vs),
+ ("location(0)", template_vs),
+ ("size(16)", template_struct),
+ ("vertex", template_stage),
+ ("early_depth_test(less_equal)", template_resource),
+ ("workgroup_size(1)", template_stage),
+ ] {
+ let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}"));
+ let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32;
+ let span_start = shader.rfind(attribute).unwrap() as u32;
+ let span_end = span_start + name_length;
+ let expected_span = Span::new(span_start, span_end);
+
+ let result = Frontend::new().inner(&shader);
+ assert!(matches!(
+ result.unwrap_err(),
+ Error::RepeatedAttribute(span) if span == expected_span
+ ));
+ }
+}
+
+#[test]
+fn parse_missing_workgroup_size() {
+ use crate::{
+ front::wgsl::{error::Error, Frontend},
+ Span,
+ };
+
+ let shader = "@compute fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
+ let result = Frontend::new().inner(shader);
+ assert!(matches!(
+ result.unwrap_err(),
+ Error::MissingWorkgroupSize(span) if span == Span::new(1, 8)
+ ));
+}