summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/webnn/validation_tests/where.https.any.js
diff options
context:
space:
mode:
Diffstat (limited to 'testing/web-platform/tests/webnn/validation_tests/where.https.any.js')
-rw-r--r--testing/web-platform/tests/webnn/validation_tests/where.https.any.js129
1 files changed, 129 insertions, 0 deletions
diff --git a/testing/web-platform/tests/webnn/validation_tests/where.https.any.js b/testing/web-platform/tests/webnn/validation_tests/where.https.any.js
new file mode 100644
index 0000000000..a26fa24931
--- /dev/null
+++ b/testing/web-platform/tests/webnn/validation_tests/where.https.any.js
@@ -0,0 +1,129 @@
+// META: title=validation tests for WebNN API where operation
+// META: global=window,dedicatedworker
+// META: script=../resources/utils_validation.js
+
+'use strict';
+
+const kExampleConditionDescriptor = {
+ dataType: 'uint8',
+ dimensions: [2, 4]
+};
+const kExampleInputDescriptor = {
+ dataType: 'float32',
+ dimensions: [2, 4]
+};
+
+const tests = [
+ {
+ name:
+ '[where] Throw if the condition data type is not uint8.',
+ condition: {dataType: 'float32', dimensions: [2, 4]},
+ input: {dataType: 'float32', dimensions: [2, 4]},
+ other: {dataType: 'float32', dimensions: [2, 4]},
+ },
+ {
+ name:
+ '[where] Throw if the data types of input and other do not match',
+ condition: {dataType: 'uint8', dimensions: [2, 4]},
+ input: {dataType: 'float16', dimensions: [2, 4]},
+ other: {dataType: 'float32', dimensions: [2, 4]},
+ },
+ {
+ name:
+ '[where] Throw if the shapes of input and other are not broadcastable',
+ condition: {dataType: 'uint8', dimensions: [2, 4]},
+ input: {dataType: 'float32', dimensions: [2, 4]},
+ other: {dataType: 'float32', dimensions: [2, 3]},
+ },
+ {
+ name:
+ '[where] Throw if the condition shape is not broadcastable',
+ condition: {dataType: 'uint8', dimensions: [2, 4]},
+ input: {dataType: 'float32', dimensions: [2, 3]},
+ other: {dataType: 'float32', dimensions: [2, 1]},
+ },
+ {
+ name:
+ '[where] Test building where with 2-D condition, 2-D input and 2-D other using broadcast',
+ condition: {dataType: 'uint8', dimensions: [2, 1]},
+ input: {dataType: 'float32', dimensions: [2, 4]},
+ other: {dataType: 'float32', dimensions: [2, 4]},
+ output: {dataType: 'float32', dimensions: [2, 4]},
+ },
+ {
+ name:
+ '[where] Test building where with 2-D condition, 2-D input and 3-D other using broadcast',
+ condition: {dataType: 'uint8', dimensions: [1, 4]},
+ input: {dataType: 'float32', dimensions: [3, 4]},
+ other: {dataType: 'float32', dimensions: [2, 3, 4]},
+ output: {dataType: 'float32', dimensions: [2, 3, 4]},
+ },
+ {
+ name:
+ '[where] Test building where with 3-D condition, 3-D input and 2-D other using broadcast',
+ condition: {dataType: 'uint8', dimensions: [2, 1, 4]},
+ input: {dataType: 'float32', dimensions: [2, 3, 4]},
+ other: {dataType: 'float32', dimensions: [1, 4]},
+ output: {dataType: 'float32', dimensions: [2, 3, 4]},
+ },
+ {
+ name:
+ '[where] Test building where with 4-D condition, 3-D input and 2-D other using broadcast',
+ condition: {dataType: 'uint8', dimensions: [2, 3, 4, 5]},
+ input: {dataType: 'float32', dimensions: [3, 4, 5]},
+ other: {dataType: 'float32', dimensions: [4, 5]},
+ output: {dataType: 'float32', dimensions: [2, 3, 4, 5]},
+ }
+];
+
+tests.forEach(
+ test => promise_test(async t => {
+ const condition = builder.input('condition', {
+ dataType: test.condition.dataType,
+ dimensions: test.condition.dimensions
+ });
+ const input = builder.input(
+ 'input',
+ {dataType: test.input.dataType, dimensions: test.input.dimensions});
+ const other = builder.input(
+ 'other',
+ {dataType: test.other.dataType, dimensions: test.other.dimensions});
+ if (test.output) {
+ const output = builder.where(condition, input, other);
+ assert_equals(output.dataType(), test.output.dataType);
+ assert_array_equals(output.shape(), test.output.dimensions);
+ } else {
+ assert_throws_js(
+ TypeError, () => builder.where(condition, input, other));
+ }
+ }, test.name));
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const conditionFromOtherBuilder =
+ otherBuilder.input('condition', kExampleConditionDescriptor);
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const other = builder.input('other', kExampleInputDescriptor);
+ assert_throws_js(
+ TypeError, () => builder.where(conditionFromOtherBuilder, input, other));
+}, '[where] throw if condition is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const inputFromOtherBuilder =
+ otherBuilder.input('input', kExampleInputDescriptor);
+
+ const condition = builder.input('condition', kExampleConditionDescriptor);
+ const other = builder.input('other', kExampleInputDescriptor);
+ assert_throws_js(
+ TypeError, () => builder.where(condition, inputFromOtherBuilder, other));
+}, '[where] throw if input is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const otherFromOtherBuilder =
+ otherBuilder.input('other', kExampleInputDescriptor);
+
+ const condition = builder.input('condition', kExampleConditionDescriptor);
+ const input = builder.input('input', kExampleInputDescriptor);
+ assert_throws_js(
+ TypeError, () => builder.where(condition, input, otherFromOtherBuilder));
+}, '[where] throw if other is from another builder');