// META: title=validation tests for WebNN API argMin/Max operations
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Names of the MLGraphBuilder methods exercised by this file. Every test case
// declared below is registered once per operator name.
const kArgMinMaxOperators = [
  'argMin',
  'argMax',
];

// Label passed through `options.label` by the cases that are expected to
// fail; runTests() checks for it in bracketed form.
const label = 'arg_min_max_1_!';

// Test case descriptors consumed by runTests(). Fields:
//   name    - test title; the '[argMin/Max]' prefix is replaced with the
//             concrete operator name when the test is registered.
//   input   - descriptor handed to builder.input().
//   axis    - axis argument passed to the operator.
//   options - optional options dictionary passed to the operator.
//   output  - expected output shape. When absent, the builder call is
//             expected to throw instead of returning an operand.
const tests = [
  {
    name: '[argMin/Max] Test with default options.',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 0,
    output: {shape: [2, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with axes=1.',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    output: {shape: [1, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with outputDataType=int32',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    options: {
      outputDataType: 'int32',
    },
    output: {shape: [1, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with outputDataType=int64',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    options: {
      outputDataType: 'int64',
    },
    output: {shape: [1, 3, 4]}
  },
  {
    name:
        '[argMin/Max] Throw if the value in axis is greater than or equal to input rank.',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 4,
    options: {
      label: label,
    },
  },
  {
    name: '[argMin/Max] Throw if input is a scalar and axis=0.',
    input: {dataType: 'float32', shape: []},
    axis: 0,
    options: {
      label: label,
    },
  },
];

/**
 * Registers one `promise_test` per entry of `tests` for the given operator.
 *
 * Each registered test builds an input operand and then checks, in order:
 *   1. If the context's opSupportLimits() for this operator does not list the
 *      input data type, the builder call must throw a TypeError.
 *   2. If `options.outputDataType` is set, the call must either return an
 *      operand of that data type with shape `test.output.shape`, or throw a
 *      TypeError when opSupportLimits() does not list that output data type.
 *   3. If `test.output` is set, the call must return an 'int32' operand with
 *      shape `test.output.shape`.
 *   4. Otherwise the call is handed to assert_throws_with_label() together
 *      with a pattern matching the bracketed label `[arg_min_max_1_!]`.
 *
 * `context`, `MLGraphBuilder` and the assert_* helpers are not defined in this
 * file; they are presumably provided by the test harness and the script named
 * in the META header.
 *
 * @param {string} operatorName - Name of the MLGraphBuilder method to invoke.
 * @param {Array<Object>} tests - Case descriptors with `name`, `input`,
 *     `axis`, and optional `options` / `output`.
 */
function runTests(operatorName, tests) {
  // Pattern for the bracketed form of the label used by the failing cases.
  const labelPattern = /\[arg_min_max_1_\!\]/;

  tests.forEach(test => {
    promise_test(async () => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      // Single definition of the call under test, shared by every branch.
      const buildOutput = () =>
          builder[operatorName](input, test.axis, test.options);
      const limits = context.opSupportLimits()[operatorName];

      if (!limits.input.dataTypes.includes(test.input.dataType)) {
        assert_throws_js(TypeError, buildOutput);
        return;
      }

      if (test.options && test.options.outputDataType !== undefined) {
        if (limits.output.dataTypes.includes(test.options.outputDataType)) {
          const output = buildOutput();
          assert_equals(output.dataType, test.options.outputDataType);
          assert_array_equals(output.shape, test.output.shape);
        } else {
          assert_throws_js(TypeError, buildOutput);
        }
        return;
      }

      if (test.output) {
        const output = buildOutput();
        // No outputDataType was requested, so 'int32' is the expected type.
        assert_equals(output.dataType, 'int32');
        assert_array_equals(output.shape, test.output.shape);
      } else {
        assert_throws_with_label(buildOutput, labelPattern);
      }
    }, test.name.replace('[argMin/Max]', `[${operatorName}]`));
  });
}

// Entry point: for each operator, first call validateInputFromAnotherBuilder()
// (a helper not defined in this file; presumably supplied by the script listed
// in the META header), then register the operator-specific cases from `tests`.
kArgMinMaxOperators.forEach((operatorName) => {
  validateInputFromAnotherBuilder(operatorName);
  runTests(operatorName, tests);
});