122 lines
3.3 KiB
JavaScript
122 lines
3.3 KiB
JavaScript
// META: title=validation tests for WebNN API split operation
|
|
// META: global=window
|
|
// META: variant=?cpu
|
|
// META: variant=?gpu
|
|
// META: variant=?npu
|
|
// META: script=../resources/utils_validation.js
|
|
|
|
'use strict';
|
|
|
|
multi_builder_test(async (t, builder, otherBuilder) => {
|
|
const inputFromOtherBuilder =
|
|
otherBuilder.input('input', {dataType: 'float32', shape: [4, 4]});
|
|
|
|
const splits = 2;
|
|
assert_throws_js(
|
|
TypeError, () => builder.split(inputFromOtherBuilder, splits));
|
|
}, '[split] throw if input is from another builder');
|
|
|
|
const label = 'xxx-split';
|
|
const tests = [
|
|
{
|
|
name: '[split] Test with default options.',
|
|
input: {dataType: 'float32', shape: [2, 6]},
|
|
splits: [2],
|
|
outputs: [
|
|
{dataType: 'float32', shape: [2, 6]},
|
|
]
|
|
},
|
|
{
|
|
name:
|
|
'[split] Test with a sequence of unsigned long splits and with options.axis = 1.',
|
|
input: {dataType: 'float32', shape: [2, 6]},
|
|
splits: [1, 2, 3],
|
|
options: {axis: 1},
|
|
outputs: [
|
|
{dataType: 'float32', shape: [2, 1]},
|
|
{dataType: 'float32', shape: [2, 2]},
|
|
{dataType: 'float32', shape: [2, 3]},
|
|
]
|
|
},
|
|
{
|
|
name: '[split] Throw if splitting a scalar.',
|
|
input: {dataType: 'float32', shape: []},
|
|
splits: [2],
|
|
options: {label}
|
|
},
|
|
{
|
|
name: '[split] Throw if axis is larger than input rank.',
|
|
input: {dataType: 'float32', shape: [2, 6]},
|
|
splits: [2],
|
|
options: {
|
|
axis: 2,
|
|
label: label,
|
|
}
|
|
},
|
|
{
|
|
name: '[split] Throw if splits is equal to 0.',
|
|
input: {dataType: 'float32', shape: [2, 6]},
|
|
splits: [0],
|
|
options: {
|
|
axis: 0,
|
|
label: label,
|
|
}
|
|
},
|
|
{
|
|
name: '[split] Throw if splits (scalar) is equal to 0.',
|
|
input: {dataType: 'float32', shape: [2, 6]},
|
|
splits: 0,
|
|
options: {
|
|
axis: 0,
|
|
label: label,
|
|
},
|
|
},
|
|
{
|
|
name:
|
|
'[split] Throw if the splits can not evenly divide the dimension size of input along options.axis.',
|
|
input: {dataType: 'float32', shape: [2, 5]},
|
|
splits: [2],
|
|
options: {
|
|
axis: 1,
|
|
label: label,
|
|
}
|
|
},
|
|
{
|
|
name:
|
|
'[split] Throw if splits (scalar) can not evenly divide the dimension size of input along options.axis.',
|
|
input: {dataType: 'float32', shape: [2, 5]},
|
|
splits: 2,
|
|
options: {
|
|
axis: 1,
|
|
label: label,
|
|
},
|
|
},
|
|
{
|
|
name:
|
|
'[split] Throw if the sum of splits sizes not equal to the dimension size of input along options.axis.',
|
|
input: {dataType: 'float32', shape: [2, 6]},
|
|
splits: [2, 2, 3],
|
|
options: {
|
|
axis: 1,
|
|
label: label,
|
|
}
|
|
},
|
|
];
|
|
|
|
tests.forEach(
|
|
test => promise_test(async t => {
|
|
const builder = new MLGraphBuilder(context);
|
|
const input = builder.input('input', test.input);
|
|
if (test.outputs) {
|
|
const outputs = builder.split(input, test.splits, test.options);
|
|
assert_equals(outputs.length, test.outputs.length);
|
|
for (let i = 0; i < outputs.length; ++i) {
|
|
assert_equals(outputs[i].dataType, test.outputs[i].dataType);
|
|
assert_array_equals(outputs[i].shape, test.outputs[i].shape);
|
|
}
|
|
} else {
|
|
const regrexp = new RegExp('\\[' + label + '\\]');
|
|
assert_throws_with_label(
|
|
() => builder.split(input, test.splits, test.options), regrexp);
|
|
}
|
|
}, test.name));
|