summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js
diff options
context:
space:
mode:
Diffstat (limited to 'testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js')
-rw-r--r--testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js471
1 files changed, 471 insertions, 0 deletions
diff --git a/testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js b/testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js
new file mode 100644
index 0000000000..3cd9d32b07
--- /dev/null
+++ b/testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js
@@ -0,0 +1,471 @@
+// META: title=validation tests for WebNN API gruCell operation
+// META: global=window,dedicatedworker
+// META: script=../resources/utils_validation.js
+
+'use strict';
+
+const batchSize = 3, inputSize = 4, hiddenSize = 5;
+
+// Dimensions required of required inputs.
+const kValidInputDimensions = [batchSize, inputSize];
+const kValidWeightDimensions = [3 * hiddenSize, inputSize];
+const kValidRecurrentWeightDimensions = [3 * hiddenSize, hiddenSize];
+const kValidHiddenStateDimensions = [batchSize, hiddenSize];
+// Dimensions required of optional inputs.
+const kValidBiasDimensions = [3 * hiddenSize];
+const kValidRecurrentBiasDimensions = [3 * hiddenSize];
+// Dimensions required of required output.
+const kValidOutputDimensions = [batchSize, hiddenSize];
+
+// Example descriptors which are valid according to the above dimensions.
+const kExampleInputDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidInputDimensions
+};
+const kExampleWeightDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidWeightDimensions
+};
+const kExampleRecurrentWeightDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidRecurrentWeightDimensions
+};
+const kExampleHiddenStateDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidHiddenStateDimensions
+};
+const kExampleBiasDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidBiasDimensions
+};
+const kExampleRecurrentBiasDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidRecurrentBiasDimensions
+};
+const kExampleOutputDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidOutputDimensions
+ };
+
+const tests = [
+ {
+ name: '[gruCell] Test with default options',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ output: kExampleOutputDescriptor
+ },
+ {
+ name: '[gruCell] Test with given options',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: {
+ bias: kExampleBiasDescriptor,
+ recurrentBias: kExampleRecurrentBiasDescriptor,
+ restAfter: true,
+ layout: 'rzn',
+ activations: ['sigmoid', 'relu']
+ },
+ output: kExampleOutputDescriptor
+ },
+ {
+ name: '[gruCell] Throw if hiddenSize equals to zero',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: 0
+ },
+ {
+ name: '[gruCell] Throw if hiddenSize is too large',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: 4294967295,
+ },
+ {
+ name:
+ '[gruCell] Throw if the data type of the inputs is not one of the floating point types',
+ input: { dataType: 'uint32', dimensions: kValidInputDimensions },
+ weight: { dataType: 'uint32', dimensions: kValidWeightDimensions },
+ recurrentWeight: {
+ dataType: 'uint32',
+ dimensions: kValidRecurrentWeightDimensions
+ },
+ hiddenState: {
+ dataType: 'uint32',
+ dimensions: kValidHiddenStateDimensions
+ },
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of input is not 2',
+ input: { dataType: 'float32', dimensions: [batchSize] },
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the input.dimensions[1] is incorrect',
+ input: { dataType: 'float32', dimensions: [inputSize, inputSize] },
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[gruCell] Throw if data type of weight is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: {
+ dataType: 'int8',
+ dimensions: [3 * hiddenSize, inputSize]
+ },
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[gruCell] Throw if rank of weight is not 2',
+ input: kExampleInputDescriptor,
+ weight: {
+ dataType: 'float32',
+ dimensions: [3 * hiddenSize]
+ },
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[gruCell] Throw if weight.dimensions[0] is not 3 * hiddenSize',
+ input: kExampleInputDescriptor,
+ weight: {
+ dataType: 'float32',
+ dimensions: [4 * hiddenSize, inputSize]
+ },
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[gruCell] Throw if data type of recurrentWeight is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: {
+ dataType: 'int32',
+ dimensions: [3 * hiddenSize, hiddenSize]
+ },
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of recurrentWeight is not 2',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight:
+ { dataType: 'float32', dimensions: [3 * hiddenSize] },
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the recurrentWeight.dimensions is invalid',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight:
+ { dataType: 'float32', dimensions: [4 * hiddenSize, inputSize] },
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if data type of hiddenState is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight:
+ kExampleRecurrentWeightDescriptor,
+ hiddenState: {
+ dataType: 'uint32',
+ dimensions: [batchSize, hiddenSize]
+ },
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of hiddenState is not 2',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight:
+ kExampleRecurrentWeightDescriptor,
+ hiddenState: {
+ dataType: 'float32',
+ dimensions: [hiddenSize]
+ },
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the hiddenState.dimensions is invalid',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: {
+ dataType: 'float32',
+ dimensions: [batchSize, 3 * hiddenSize]
+ },
+ hiddenSize: hiddenSize
+ },
+ {
+ name:
+ '[gruCell] Throw if the size of options.activations is not 2',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { activations: ['sigmoid', 'tanh', 'relu'] }
+ },
+ {
+ name:
+ '[gruCell] Throw if data type of options.bias is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { bias: { dataType: 'uint8', dimensions: [3 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of options.bias is not 1',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { bias: { dataType: 'float32', dimensions: [batchSize, 3 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if options.bias.dimensions[0] is not 3 * hiddenSize',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { bias: { dataType: 'float32', dimensions: [2 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if data type of options.recurrentBias is not one of the floating point types',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { recurrentBias: { dataType: 'int8', dimensions: [3 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if the rank of options.recurrentBias is not 1',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: { recurrentBias: { dataType: 'float32', dimensions: [batchSize, 3 * hiddenSize] } }
+ },
+ {
+ name:
+ '[gruCell] Throw if options.recurrentBias.dimensions[0] is not 3 * hiddenSize',
+ input: kExampleInputDescriptor,
+ weight: kExampleWeightDescriptor,
+ recurrentWeight: kExampleRecurrentWeightDescriptor,
+ hiddenState: kExampleHiddenStateDescriptor,
+ hiddenSize: hiddenSize,
+ options: {
+ recurrentBias: { dataType: 'float16', dimensions: [4 * hiddenSize] }
+ }
+ }
+];
+
+tests.forEach(
+ test => promise_test(async t => {
+ const input = builder.input(
+ 'input',
+ { dataType: test.input.dataType, dimensions: test.input.dimensions });
+ const weight = builder.input(
+ 'weight',
+ { dataType: test.weight.dataType, dimensions: test.weight.dimensions });
+ const recurrentWeight = builder.input('recurrentWeight', {
+ dataType: test.recurrentWeight.dataType,
+ dimensions: test.recurrentWeight.dimensions
+ });
+ const hiddenState = builder.input('hiddenState', {
+ dataType: test.hiddenState.dataType,
+ dimensions: test.hiddenState.dimensions
+ });
+
+ const options = {};
+ if (test.options) {
+ if (test.options.bias) {
+ options.bias = builder.input('bias', {
+ dataType: test.options.bias.dataType,
+ dimensions: test.options.bias.dimensions
+ });
+ }
+ if (test.options.recurrentBias) {
+ options.bias = builder.input('recurrentBias', {
+ dataType: test.options.recurrentBias.dataType,
+ dimensions: test.options.recurrentBias.dimensions
+ });
+ }
+ if (test.options.resetAfter) {
+ options.resetAfter = test.options.resetAfter;
+ }
+ if (test.options.layout) {
+ options.layout = test.options.layout;
+ }
+ if (test.options.activations) {
+ options.activations = [];
+ test.options.activations.forEach(
+ activation => options.activations.push(builder[activation]()));
+ }
+ }
+
+ if (test.output) {
+ const output = builder.gruCell(
+ input, weight, recurrentWeight, hiddenState, test.hiddenSize,
+ options);
+ assert_equals(output.dataType(), test.output.dataType);
+ assert_array_equals(output.shape(), test.output.dimensions);
+ } else {
+ assert_throws_js(
+ TypeError,
+ () => builder.gruCell(
+ input, weight, recurrentWeight, hiddenState, test.hiddenSize,
+ options));
+ }
+ }, test.name));
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const inputFromOtherBuilder =
+ otherBuilder.input('input', kExampleInputDescriptor);
+
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.gruCell(
+ inputFromOtherBuilder, weight, recurrentWeight, hiddenState,
+ hiddenSize));
+}, '[gruCell] throw if input is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const weightFromOtherBuilder =
+ otherBuilder.input('weight', kExampleWeightDescriptor);
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.gruCell(
+ input, weightFromOtherBuilder, recurrentWeight, hiddenState,
+ hiddenSize));
+}, '[gruCell] throw if weight is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const recurrentWeightFromOtherBuilder =
+ otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.gruCell(
+ input, weight, recurrentWeightFromOtherBuilder, hiddenState,
+ hiddenSize));
+}, '[gruCell] throw if recurrentWeight is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const hiddenStateFromOtherBuilder =
+ otherBuilder.input('hiddenState', kExampleHiddenStateDescriptor);
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.gruCell(
+ input, weight, recurrentWeight, hiddenStateFromOtherBuilder,
+ hiddenSize));
+}, '[gruCell] throw if hiddenState is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const biasFromOtherBuilder =
+ otherBuilder.input('bias', kExampleBiasDescriptor);
+ const options = {bias: biasFromOtherBuilder};
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.gruCell(
+ input, weight, recurrentWeight, hiddenState, hiddenSize, options));
+}, '[gruCell] throw if bias option is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const recurrentBiasFromOtherBuilder =
+ otherBuilder.input('recurrentBias', kExampleRecurrentBiasDescriptor);
+ const options = {recurrentBias: recurrentBiasFromOtherBuilder};
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.gruCell(
+ input, weight, recurrentWeight, hiddenState, hiddenSize, options));
+}, '[gruCell] throw if recurrentBias option is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const activation = builder.clamp();
+ const activationFromOtherBuilder = otherBuilder.clamp();
+ const options = {activations: [activation, activationFromOtherBuilder]};
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.gruCell(
+ input, weight, recurrentWeight, hiddenState, hiddenSize, options));
+}, '[gruCell] throw if any activation option is from another builder');