136 lines
4.1 KiB
JavaScript
136 lines
4.1 KiB
JavaScript
// META: title=validation tests for WebNN API concat operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js
|
|
|
|
'use strict';
|
|
|
|
const label = `concate_123`;

// Table of builder.concat() cases. Entries that carry an `output` describe a
// graph that must build and produce that descriptor; entries without one
// describe invalid arguments that are expected to make concat() throw.
const tests = (() => {
  // Shorthand for the float32 operand descriptors most cases use.
  const f32 = (shape) => ({dataType: 'float32', shape});

  return [
    // Valid graphs.
    {
      name: '[concat] Test building Concat with one input.',
      inputs: [f32([4, 4, 3])],
      axis: 2,
      output: f32([4, 4, 3]),
    },
    {
      name: '[concat] Test building Concat with two inputs',
      inputs: [f32([3, 1, 5]), f32([3, 2, 5])],
      axis: 1,
      output: f32([3, 3, 5]),
    },
    {
      name: '[concat] Test building Concat with three inputs',
      inputs: [f32([3, 5, 1]), f32([3, 5, 2]), f32([3, 5, 3])],
      axis: 2,
      output: f32([3, 5, 6]),
    },
    {
      name: '[concat] Test building Concat with two 1D inputs.',
      inputs: [f32([1]), f32([1])],
      axis: 0,
      output: f32([2]),
    },

    // Invalid graphs: concat() is expected to throw.
    {
      name: '[concat] Throw if the inputs are empty.',
      axis: 0,
    },
    {
      name: '[concat] Throw if the argument types are inconsistent.',
      inputs: [f32([1, 1]), {dataType: 'int32', shape: [1, 1]}],
      axis: 0,
    },
    {
      name: '[concat] Throw if the inputs have different ranks.',
      inputs: [f32([1, 1]), f32([1, 1, 1])],
      axis: 0,
    },
    {
      name:
          '[concat] Throw if the axis is equal to or greater than the size of ranks',
      inputs: [f32([1, 1]), f32([1, 1])],
      axis: 2,
    },
    {
      name: '[concat] Throw if concat with two 0-D scalars.',
      inputs: [f32([]), f32([])],
      axis: 0,
    },
    {
      name:
          '[concat] Throw if the inputs have other axes with different sizes except on the axis.',
      inputs: [f32([1, 1, 1]), f32([1, 2, 3])],
      axis: 1,
    },
  ];
})();
|
|
|
|
// Registers one promise_test per entry in `tests`. Entries with an `output`
// must build and yield an operand whose dataType and shape match it. All other
// entries must throw; assert_throws_with_label is given a `[<label>]` pattern
// to check the error against.
tests.forEach((test) => {
  promise_test(async () => {
    const builder = new MLGraphBuilder(context);
    const inputs = test.inputs ?
        test.inputs.map(
            (descriptor, index) =>
                builder.input(`inputs[${index}]`, descriptor)) :
        [];

    if (!test.output) {
      const labelPattern = new RegExp(`\\[${label}\\]`);
      assert_throws_with_label(
          () => builder.concat(inputs, test.axis, {label}), labelPattern);
      return;
    }

    const output = builder.concat(inputs, test.axis);
    assert_equals(output.dataType, test.output.dataType);
    assert_array_equals(output.shape, test.output.shape);
  }, test.name);
});
|
|
|
|
// Every operand passed to concat() must have been created by the builder that
// concat() is called on. Mixing in an operand from `otherBuilder` must throw a
// TypeError.
multi_builder_test(async (t, builder, otherBuilder) => {
  const operandDescriptor = {dataType: 'float32', shape: [2, 2]};

  const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);

  // Distinct names keep input creation itself from being a possible failure
  // point; only the concat() call below is under test.
  const input1 = builder.input('input1', operandDescriptor);
  const input2 = builder.input('input2', operandDescriptor);
  const input3 = builder.input('input3', operandDescriptor);

  // `axis` is a required argument of concat(). Without it, the call throws a
  // TypeError for the missing argument and the cross-builder check is never
  // reached. Pass a valid axis so the only invalid thing is the foreign operand.
  assert_throws_js(
      TypeError,
      () => builder.concat([input1, input2, inputFromOtherBuilder, input3], 0));
}, '[concat] throw if any input is from another builder');
|
|
|
|
// Concatenating three inputs that are each sized to the context's
// maxTensorByteLength yields an output larger than that limit, so concat() is
// expected to throw a TypeError.
promise_test(async t => {
  const builder = new MLGraphBuilder(context);

  // float32 elements are 4 bytes, so each 1-D input is (up to rounding of the
  // division) exactly maxTensorByteLength bytes.
  const operandDescriptor = {
    dataType: 'float32',
    shape: [context.opSupportLimits().maxTensorByteLength / 4]
  };
  const input1 = builder.input('input1', operandDescriptor);
  const input2 = builder.input('input2', operandDescriptor);
  const input3 = builder.input('input3', operandDescriptor);

  // concat() takes a sequence of operands followed by an axis. Passing the
  // operands as separate arguments would throw a TypeError from argument
  // conversion alone and never reach the output byte-length check.
  assert_throws_js(
      TypeError, () => builder.concat([input1, input2, input3], 0));
}, '[concat] throw if the output tensor byte length exceeds limit');
|