summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/webnn/validation_tests/lstmCell.https.any.js
diff options
context:
space:
mode:
Diffstat (limited to 'testing/web-platform/tests/webnn/validation_tests/lstmCell.https.any.js')
-rw-r--r--testing/web-platform/tests/webnn/validation_tests/lstmCell.https.any.js600
1 files changed, 600 insertions, 0 deletions
diff --git a/testing/web-platform/tests/webnn/validation_tests/lstmCell.https.any.js b/testing/web-platform/tests/webnn/validation_tests/lstmCell.https.any.js
new file mode 100644
index 0000000000..c3769c828d
--- /dev/null
+++ b/testing/web-platform/tests/webnn/validation_tests/lstmCell.https.any.js
@@ -0,0 +1,600 @@
+// META: title=validation tests for WebNN API lstmCell operation
+// META: global=window,dedicatedworker
+// META: script=../resources/utils_validation.js
+
+'use strict';
+
+const batchSize = 3, inputSize = 4, hiddenSize = 5;
+
+// Dimensions required of required inputs.
+const kValidInputDimensions = [batchSize, inputSize];
+const kValidWeightDimensions = [4 * hiddenSize, inputSize];
+const kValidRecurrentWeightDimensions = [4 * hiddenSize, hiddenSize];
+const kValidHiddenStateDimensions = [batchSize, hiddenSize];
+const kValidCellStateDimensions = [batchSize, hiddenSize];
+// Dimensions required of optional inputs.
+const kValidBiasDimensions = [4 * hiddenSize];
+const kValidPeepholeWeightDimensions = [3 * hiddenSize];
+
+// Example descriptors which are valid according to the above dimensions.
+const kExampleInputDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidInputDimensions
+};
+const kExampleWeightDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidWeightDimensions
+};
+const kExampleRecurrentWeightDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidRecurrentWeightDimensions
+};
+const kExampleHiddenStateDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidHiddenStateDimensions
+};
+const kExampleCellStateDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidCellStateDimensions
+};
+const kExampleBiasDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidBiasDimensions
+};
+const kExamplePeepholeWeightDescriptor = {
+ dataType: 'float32',
+ dimensions: kValidPeepholeWeightDimensions
+};
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const inputFromOtherBuilder =
+ otherBuilder.input('input', kExampleInputDescriptor);
+
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ const cellState = builder.input('cellState', kExampleCellStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ inputFromOtherBuilder, weight, recurrentWeight, hiddenState,
+ cellState, hiddenSize));
+}, '[lstmCell] throw if input is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const weightFromOtherBuilder =
+ otherBuilder.input('weight', kExampleWeightDescriptor);
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ const cellState = builder.input('cellState', kExampleCellStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ input, weightFromOtherBuilder, recurrentWeight, hiddenState,
+ cellState, hiddenSize));
+}, '[lstmCell] throw if weight is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const recurrentWeightFromOtherBuilder =
+ otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ const cellState = builder.input('cellState', kExampleCellStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ input, weight, recurrentWeightFromOtherBuilder, hiddenState,
+ cellState, hiddenSize));
+}, '[lstmCell] throw if recurrentWeight is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const hiddenStateFromOtherBuilder =
+ otherBuilder.input('hiddenState', kExampleHiddenStateDescriptor);
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const cellState = builder.input('cellState', kExampleCellStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ input, weight, recurrentWeight, hiddenStateFromOtherBuilder,
+ cellState, hiddenSize));
+}, '[lstmCell] throw if hiddenState is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const cellStateFromOtherBuilder =
+ otherBuilder.input('cellState', kExampleCellStateDescriptor);
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ input, weight, recurrentWeight, hiddenState,
+ cellStateFromOtherBuilder, hiddenSize));
+}, '[lstmCell] throw if cellState is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const biasFromOtherBuilder =
+ otherBuilder.input('bias', kExampleBiasDescriptor);
+ const options = {bias: biasFromOtherBuilder};
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ const cellState = builder.input('cellState', kExampleCellStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
+ options));
+}, '[lstmCell] throw if bias option is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const recurrentBiasFromOtherBuilder =
+ otherBuilder.input('bias', kExampleBiasDescriptor);
+ const options = {recurrentBias: recurrentBiasFromOtherBuilder};
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ const cellState = builder.input('cellState', kExampleCellStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
+ options));
+}, '[lstmCell] throw if recurrentBias option is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const peepholeWeightFromOtherBuilder =
+ otherBuilder.input('peepholeWeight', kExamplePeepholeWeightDescriptor);
+ const options = {peepholeWeight: peepholeWeightFromOtherBuilder};
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ const cellState = builder.input('cellState', kExampleCellStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
+ options));
+}, '[lstmCell] throw if peepholeWeight option is from another builder');
+
+multi_builder_test(async (t, builder, otherBuilder) => {
+ const activation = builder.clamp();
+ const activationFromOtherBuilder = otherBuilder.clamp();
+ const options = {activations: [activation, activationFromOtherBuilder]};
+
+ const input = builder.input('input', kExampleInputDescriptor);
+ const weight = builder.input('weight', kExampleWeightDescriptor);
+ const recurrentWeight =
+ builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
+ const hiddenState =
+ builder.input('hiddenState', kExampleHiddenStateDescriptor);
+ const cellState = builder.input('cellState', kExampleCellStateDescriptor);
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
+ options));
+}, '[lstmCell] throw if activation option is from another builder');
+
+const tests = [
+ {
+ name: '[lstmCell] Test with default options',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ outputs: [
+ {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ {dataType: 'float16', dimensions: [batchSize, hiddenSize]}
+ ]
+ },
+ {
+ name: '[lstmCell] Test with given options',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {
+ bias: {dataType: 'float32', dimensions: [4 * hiddenSize]},
+ recurrentBias: {dataType: 'float32', dimensions: [4 * hiddenSize]},
+ peepholeWeight: {dataType: 'float32', dimensions: [3 * hiddenSize]},
+ layout: 'ifgo',
+ activations: ['sigmoid', 'relu', 'tanh']
+ },
+ outputs: [
+ {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ {dataType: 'float32', dimensions: [batchSize, hiddenSize]}
+ ]
+ },
+ {
+ name: '[lstmCell] Throw if hiddenSize is equal to zero',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: 0
+ },
+ {
+ name: '[lstmCell] Throw if hiddenSize is too large',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: 4294967295
+ },
+ {
+ name:
+ '[lstmCell] Throw if the input data type is not one of the floating point types',
+ input: {dataType: 'uint32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the rank of input is not 2',
+ input: {dataType: 'float32', dimensions: [batchSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the shape of input is incorrect',
+ input: {dataType: 'float32', dimensions: [batchSize, 1000]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the data type of weight is incorrect',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the rank of weight is not 2',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize, 1000]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the shape of weight is incorrect',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [1000, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the data type of recurrentWeight is incorrect',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the rank of recurrentWeight is not 2',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [1000, 4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the shape of recurrentWeight is incorrect',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight: {dataType: 'float32', dimensions: [1000, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the data type of hiddenState is incorrect',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'int64', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the rank of hiddenState is not 2',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the shape of hiddenState is incorrect',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, 1000]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the data type of cellState is incorrect',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the rank of cellState is not 2',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the shape of cellState is incorrect',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, 1000]},
+ hiddenSize: hiddenSize
+ },
+ {
+ name: '[lstmCell] Throw if the data type of options.bias is incorrect',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {bias: {dataType: 'int8', dimensions: [4 * hiddenSize]}}
+ },
+ {
+ name: '[lstmCell] Throw if the rank of options.bias is not 1',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {bias: {dataType: 'float16', dimensions: [4 * hiddenSize, 1000]}}
+ },
+ {
+ name: '[lstmCell] Throw if the shape of options.bias is incorrect',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {bias: {dataType: 'float16', dimensions: [1000]}}
+ },
+ {
+ name:
+ '[lstmCell] Throw if the data type of options.recurrentBias is incorrect',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {recurrentBias: {dataType: 'uint8', dimensions: [4 * hiddenSize]}}
+ },
+ {
+ name: '[lstmCell] Throw if the rank of options.recurrentBias is not 1',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {
+ recurrentBias: {dataType: 'float16', dimensions: [4 * hiddenSize, 1000]}
+ }
+ },
+ {
+ name: '[lstmCell] Throw if the shape of options.recurrentBias is incorrect',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {recurrentBias: {dataType: 'float16', dimensions: [1000]}}
+ },
+ {
+ name:
+ '[lstmCell] Throw if the data type of options.peepholeWeight is incorrect',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options:
+ {peepholeWeight: {dataType: 'float32', dimensions: [3 * hiddenSize]}}
+ },
+ {
+ name: '[lstmCell] Throw if the rank of options.peepholeWeight is not 1',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {peepholeWeight: {dataType: 'float16', dimensions: []}}
+ },
+ {
+ name:
+ '[lstmCell] Throw if the shape of options.peepholeWeight is incorrect',
+ input: {dataType: 'float16', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float16', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float16', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float16', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {peepholeWeight: {dataType: 'float16', dimensions: [1000]}}
+ },
+ {
+ name: '[lstmCell] Throw if the size of options.activations is not 3',
+ input: {dataType: 'float32', dimensions: [batchSize, inputSize]},
+ weight: {dataType: 'float32', dimensions: [4 * hiddenSize, inputSize]},
+ recurrentWeight:
+ {dataType: 'float32', dimensions: [4 * hiddenSize, hiddenSize]},
+ hiddenState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ cellState: {dataType: 'float32', dimensions: [batchSize, hiddenSize]},
+ hiddenSize: hiddenSize,
+ options: {activations: ['sigmoid', 'tanh', 'sigmoid', 'tanh']}
+ }
+];
+
+tests.forEach(
+ test => promise_test(async t => {
+ const input = builder.input(
+ 'input',
+ {dataType: test.input.dataType, dimensions: test.input.dimensions});
+ const weight = builder.input(
+ 'weight',
+ {dataType: test.weight.dataType, dimensions: test.weight.dimensions});
+ const recurrentWeight = builder.input('recurrentWeight', {
+ dataType: test.recurrentWeight.dataType,
+ dimensions: test.recurrentWeight.dimensions
+ });
+ const hiddenState = builder.input('hiddenState', {
+ dataType: test.hiddenState.dataType,
+ dimensions: test.hiddenState.dimensions
+ });
+ const cellState = builder.input('cellState', {
+ dataType: test.cellState.dataType,
+ dimensions: test.cellState.dimensions
+ });
+
+ const options = {};
+ if (test.options) {
+ if (test.options.bias) {
+ options.bias = builder.input('bias', {
+ dataType: test.options.bias.dataType,
+ dimensions: test.options.bias.dimensions
+ });
+ }
+ if (test.options.recurrentBias) {
+ options.bias = builder.input('recurrentBias', {
+ dataType: test.options.recurrentBias.dataType,
+ dimensions: test.options.recurrentBias.dimensions
+ });
+ }
+ if (test.options.peepholeWeight) {
+ options.peepholeWeight = builder.input('peepholeWeight', {
+ dataType: test.options.peepholeWeight.dataType,
+ dimensions: test.options.peepholeWeight.dimensions
+ });
+ }
+ if (test.options.layout) {
+ options.layout = test.options.layout;
+ }
+ if (test.options.activations) {
+ options.activations = [];
+ test.options.activations.forEach(
+ activation => options.activations.push(builder[activation]()));
+ }
+ }
+
+ if (test.outputs) {
+ const outputs = builder.lstmCell(
+ input, weight, recurrentWeight, hiddenState, cellState,
+ test.hiddenSize, options);
+ assert_equals(outputs.length, test.outputs.length);
+ for (let i = 0; i < outputs.length; ++i) {
+ assert_equals(outputs[i].dataType(), test.outputs[i].dataType);
+ assert_array_equals(outputs[i].shape(), test.outputs[i].dimensions);
+ }
+ } else {
+ assert_throws_js(
+ TypeError,
+ () => builder.lstmCell(
+ input, weight, recurrentWeight, hiddenState, cellState,
+ test.hiddenSize, options));
+ }
+ }, test.name));