diff options
Diffstat (limited to 'testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js')
-rw-r--r-- | testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js | 471 |
1 files changed, 471 insertions, 0 deletions
diff --git a/testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js b/testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js new file mode 100644 index 0000000000..3cd9d32b07 --- /dev/null +++ b/testing/web-platform/tests/webnn/validation_tests/gruCell.https.any.js @@ -0,0 +1,471 @@ +// META: title=validation tests for WebNN API gruCell operation +// META: global=window,dedicatedworker +// META: script=../resources/utils_validation.js + +'use strict'; + +const batchSize = 3, inputSize = 4, hiddenSize = 5; + +// Dimensions required of required inputs. +const kValidInputDimensions = [batchSize, inputSize]; +const kValidWeightDimensions = [3 * hiddenSize, inputSize]; +const kValidRecurrentWeightDimensions = [3 * hiddenSize, hiddenSize]; +const kValidHiddenStateDimensions = [batchSize, hiddenSize]; +// Dimensions required of optional inputs. +const kValidBiasDimensions = [3 * hiddenSize]; +const kValidRecurrentBiasDimensions = [3 * hiddenSize]; +// Dimensions required of required output. +const kValidOutputDimensions = [batchSize, hiddenSize]; + +// Example descriptors which are valid according to the above dimensions. +const kExampleInputDescriptor = { + dataType: 'float32', + dimensions: kValidInputDimensions +}; +const kExampleWeightDescriptor = { + dataType: 'float32', + dimensions: kValidWeightDimensions +}; +const kExampleRecurrentWeightDescriptor = { + dataType: 'float32', + dimensions: kValidRecurrentWeightDimensions +}; +const kExampleHiddenStateDescriptor = { + dataType: 'float32', + dimensions: kValidHiddenStateDimensions +}; +const kExampleBiasDescriptor = { + dataType: 'float32', + dimensions: kValidBiasDimensions +}; +const kExampleRecurrentBiasDescriptor = { + dataType: 'float32', + dimensions: kValidRecurrentBiasDimensions +}; +const kExampleOutputDescriptor = { + dataType: 'float32', + dimensions: kValidOutputDimensions + }; + +const tests = [ + { + name: '[gruCell] Test with default options', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize, + output: kExampleOutputDescriptor + }, + { + name: '[gruCell] Test with given options', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize, + options: { + bias: kExampleBiasDescriptor, + recurrentBias: kExampleRecurrentBiasDescriptor, + restAfter: true, + layout: 'rzn', + activations: ['sigmoid', 'relu'] + }, + output: kExampleOutputDescriptor + }, + { + name: '[gruCell] Throw if hiddenSize equals to zero', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: 0 + }, + { + name: '[gruCell] Throw if hiddenSize is too large', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: 4294967295, + }, + { + name: + '[gruCell] Throw if the data type of the inputs is not one of the floating point types', + input: { dataType: 'uint32', dimensions: kValidInputDimensions }, + weight: { dataType: 'uint32', dimensions: kValidWeightDimensions }, + recurrentWeight: { + dataType: 'uint32', + dimensions: kValidRecurrentWeightDimensions + }, + hiddenState: { + dataType: 'uint32', + dimensions: kValidHiddenStateDimensions + }, + hiddenSize: hiddenSize + }, + { + name: + '[gruCell] Throw if the rank of input is not 2', + input: { dataType: 'float32', dimensions: [batchSize] }, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize + }, + { + name: + '[gruCell] Throw if the input.dimensions[1] is incorrect', + input: { dataType: 'float32', dimensions: [inputSize, inputSize] }, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize + }, + { + name: '[gruCell] Throw if data type of weight is not one of the floating point types', + input: kExampleInputDescriptor, + weight: { + dataType: 'int8', + dimensions: [3 * hiddenSize, inputSize] + }, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize + }, + { + name: '[gruCell] Throw if rank of weight is not 2', + input: kExampleInputDescriptor, + weight: { + dataType: 'float32', + dimensions: [3 * hiddenSize] + }, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize + }, + { + name: '[gruCell] Throw if weight.dimensions[0] is not 3 * hiddenSize', + input: kExampleInputDescriptor, + weight: { + dataType: 'float32', + dimensions: [4 * hiddenSize, inputSize] + }, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize + }, + { + name: '[gruCell] Throw if data type of recurrentWeight is not one of the floating point types', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: { + dataType: 'int32', + dimensions: [3 * hiddenSize, hiddenSize] + }, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize + }, + { + name: + '[gruCell] Throw if the rank of recurrentWeight is not 2', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: + { dataType: 'float32', dimensions: [3 * hiddenSize] }, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize + }, + { + name: + '[gruCell] Throw if the recurrentWeight.dimensions is invalid', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: + { dataType: 'float32', dimensions: [4 * hiddenSize, inputSize] }, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize + }, + { + name: + '[gruCell] Throw if data type of hiddenState is not one of the floating point types', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: + kExampleRecurrentWeightDescriptor, + hiddenState: { + dataType: 'uint32', + dimensions: [batchSize, hiddenSize] + }, + hiddenSize: hiddenSize + }, + { + name: + '[gruCell] Throw if the rank of hiddenState is not 2', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: + kExampleRecurrentWeightDescriptor, + hiddenState: { + dataType: 'float32', + dimensions: [hiddenSize] + }, + hiddenSize: hiddenSize + }, + { + name: + '[gruCell] Throw if the hiddenState.dimensions is invalid', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: { + dataType: 'float32', + dimensions: [batchSize, 3 * hiddenSize] + }, + hiddenSize: hiddenSize + }, + { + name: + '[gruCell] Throw if the size of options.activations is not 2', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize, + options: { activations: ['sigmoid', 'tanh', 'relu'] } + }, + { + name: + '[gruCell] Throw if data type of options.bias is not one of the floating point types', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize, + options: { bias: { dataType: 'uint8', dimensions: [3 * hiddenSize] } } + }, + { + name: + '[gruCell] Throw if the rank of options.bias is not 1', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize, + options: { bias: { dataType: 'float32', dimensions: [batchSize, 3 * hiddenSize] } } + }, + { + name: + '[gruCell] Throw if options.bias.dimensions[0] is not 3 * hiddenSize', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize, + options: { bias: { dataType: 'float32', dimensions: [2 * hiddenSize] } } + }, + { + name: + '[gruCell] Throw if data type of options.recurrentBias is not one of the floating point types', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize, + options: { recurrentBias: { dataType: 'int8', dimensions: [3 * hiddenSize] } } + }, + { + name: + '[gruCell] Throw if the rank of options.recurrentBias is not 1', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize, + options: { recurrentBias: { dataType: 'float32', dimensions: [batchSize, 3 * hiddenSize] } } + }, + { + name: + '[gruCell] Throw if options.recurrentBias.dimensions[0] is not 3 * hiddenSize', + input: kExampleInputDescriptor, + weight: kExampleWeightDescriptor, + recurrentWeight: kExampleRecurrentWeightDescriptor, + hiddenState: kExampleHiddenStateDescriptor, + hiddenSize: hiddenSize, + options: { + recurrentBias: { dataType: 'float16', dimensions: [4 * hiddenSize] } + } + } +]; + +tests.forEach( + test => promise_test(async t => { + const input = builder.input( + 'input', + { dataType: test.input.dataType, dimensions: test.input.dimensions }); + const weight = builder.input( + 'weight', + { dataType: test.weight.dataType, dimensions: test.weight.dimensions }); + const recurrentWeight = builder.input('recurrentWeight', { + dataType: test.recurrentWeight.dataType, + dimensions: test.recurrentWeight.dimensions + }); + const hiddenState = builder.input('hiddenState', { + dataType: test.hiddenState.dataType, + dimensions: test.hiddenState.dimensions + }); + + const options = {}; + if (test.options) { + if (test.options.bias) { + options.bias = builder.input('bias', { + dataType: test.options.bias.dataType, + dimensions: test.options.bias.dimensions + }); + } + if (test.options.recurrentBias) { + options.bias = builder.input('recurrentBias', { + dataType: test.options.recurrentBias.dataType, + dimensions: test.options.recurrentBias.dimensions + }); + } + if (test.options.resetAfter) { + options.resetAfter = test.options.resetAfter; + } + if (test.options.layout) { + options.layout = test.options.layout; + } + if (test.options.activations) { + options.activations = []; + test.options.activations.forEach( + activation => options.activations.push(builder[activation]())); + } + } + + if (test.output) { + const output = builder.gruCell( + input, weight, recurrentWeight, hiddenState, test.hiddenSize, + options); + assert_equals(output.dataType(), test.output.dataType); + assert_array_equals(output.shape(), test.output.dimensions); + } else { + assert_throws_js( + TypeError, + () => builder.gruCell( + input, weight, recurrentWeight, hiddenState, test.hiddenSize, + options)); + } + }, test.name)); + +multi_builder_test(async (t, builder, otherBuilder) => { + const inputFromOtherBuilder = + otherBuilder.input('input', kExampleInputDescriptor); + + const weight = builder.input('weight', kExampleWeightDescriptor); + const recurrentWeight = + builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor); + const hiddenState = + builder.input('hiddenState', kExampleHiddenStateDescriptor); + assert_throws_js( + TypeError, + () => builder.gruCell( + inputFromOtherBuilder, weight, recurrentWeight, hiddenState, + hiddenSize)); +}, '[gruCell] throw if input is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const weightFromOtherBuilder = + otherBuilder.input('weight', kExampleWeightDescriptor); + + const input = builder.input('input', kExampleInputDescriptor); + const recurrentWeight = + builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor); + const hiddenState = + builder.input('hiddenState', kExampleHiddenStateDescriptor); + assert_throws_js( + TypeError, + () => builder.gruCell( + input, weightFromOtherBuilder, recurrentWeight, hiddenState, + hiddenSize)); +}, '[gruCell] throw if weight is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const recurrentWeightFromOtherBuilder = + otherBuilder.input('recurrentWeight', kExampleRecurrentWeightDescriptor); + + const input = builder.input('input', kExampleInputDescriptor); + const weight = builder.input('weight', kExampleWeightDescriptor); + const hiddenState = + builder.input('hiddenState', kExampleHiddenStateDescriptor); + assert_throws_js( + TypeError, + () => builder.gruCell( + input, weight, recurrentWeightFromOtherBuilder, hiddenState, + hiddenSize)); +}, '[gruCell] throw if recurrentWeight is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const hiddenStateFromOtherBuilder = + otherBuilder.input('hiddenState', kExampleHiddenStateDescriptor); + + const input = builder.input('input', kExampleInputDescriptor); + const weight = builder.input('weight', kExampleWeightDescriptor); + const recurrentWeight = + builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor); + assert_throws_js( + TypeError, + () => builder.gruCell( + input, weight, recurrentWeight, hiddenStateFromOtherBuilder, + hiddenSize)); +}, '[gruCell] throw if hiddenState is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const biasFromOtherBuilder = + otherBuilder.input('bias', kExampleBiasDescriptor); + const options = {bias: biasFromOtherBuilder}; + + const input = builder.input('input', kExampleInputDescriptor); + const weight = builder.input('weight', kExampleWeightDescriptor); + const recurrentWeight = + builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor); + const hiddenState = + builder.input('hiddenState', kExampleHiddenStateDescriptor); + assert_throws_js( + TypeError, + () => builder.gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, options)); +}, '[gruCell] throw if bias option is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const recurrentBiasFromOtherBuilder = + otherBuilder.input('recurrentBias', kExampleRecurrentBiasDescriptor); + const options = {recurrentBias: recurrentBiasFromOtherBuilder}; + + const input = builder.input('input', kExampleInputDescriptor); + const weight = builder.input('weight', kExampleWeightDescriptor); + const recurrentWeight = + builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor); + const hiddenState = + builder.input('hiddenState', kExampleHiddenStateDescriptor); + assert_throws_js( + TypeError, + () => builder.gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, options)); +}, '[gruCell] throw if recurrentBias option is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const activation = builder.clamp(); + const activationFromOtherBuilder = otherBuilder.clamp(); + const options = {activations: [activation, activationFromOtherBuilder]}; + + const input = builder.input('input', kExampleInputDescriptor); + const weight = builder.input('weight', kExampleWeightDescriptor); + const recurrentWeight = + builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor); + const hiddenState = + builder.input('hiddenState', kExampleHiddenStateDescriptor); + assert_throws_js( + TypeError, + () => builder.gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, options)); +}, '[gruCell] throw if any activation option is from another builder'); |