// META: title=validation tests for WebNN API argMin/Max operations
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Both operators are validated with the same case table below; every case in
// `tests` is registered once per operator.
const kArgMinMaxOperators = [
  'argMin',
  'argMax',
];

// Label passed in the options of the cases that are expected to throw. The
// thrown error is checked (via assert_throws_with_label from
// utils_validation.js) against a pattern built from this label.
const label = 'arg_min_max_1_!';

// Each case calls `builder[operator](input, axis, options)`.
// - `output` present: the call is expected to succeed and yield `output.shape`.
// - `output` absent: the call is expected to throw an error carrying `label`.
const tests = [
  {
    name: '[argMin/Max] Test with default options.',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 0,
    output: {shape: [2, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with axes=1.',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    output: {shape: [1, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with outputDataType=int32',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    options: {
      outputDataType: 'int32',
    },
    output: {shape: [1, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with outputDataType=int64',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 1,
    options: {
      outputDataType: 'int64',
    },
    output: {shape: [1, 3, 4]}
  },
  {
    name:
        '[argMin/Max] Throw if the value in axis is greater than or equal to input rank.',
    input: {dataType: 'float32', shape: [1, 2, 3, 4]},
    axis: 4,
    options: {
      label,
    },
  },
  {
    name: '[argMin/Max] Throw if input is a scalar and axis=0.',
    input: {dataType: 'float32', shape: []},
    axis: 0,
    options: {
      label,
    },
  },
];

/**
 * Registers one promise_test per case for the given operator.
 *
 * Expectations, checked in this order:
 * 1. The input dataType is not listed in opSupportLimits: the call throws
 *    TypeError.
 * 2. `options.outputDataType` is given: if it is listed in opSupportLimits,
 *    the output has that dataType and `output.shape`; otherwise the call
 *    throws TypeError.
 * 3. `output` is given: the output has dataType 'int32' and `output.shape`.
 * 4. Otherwise: the call throws an error matching the bracketed `label`.
 *
 * @param {string} operatorName Builder method under test ('argMin'/'argMax').
 * @param {Array<Object>} tests Cases shaped like the module-level `tests`.
 */
function runTests(operatorName, tests) {
  // Derived from `label` (not hand-copied) so the expectation cannot drift
  // from the label actually passed in the options; escaped so characters such
  // as '!' are matched literally.
  const escapedLabel = label.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const labelPattern = new RegExp(`\\[${escapedLabel}\\]`);

  tests.forEach((test) => {
    promise_test(async () => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input('input', test.input);
      const {axis, options} = test;
      const limits = context.opSupportLimits()[operatorName];

      if (!limits.input.dataTypes.includes(test.input.dataType)) {
        assert_throws_js(
            TypeError, () => builder[operatorName](input, axis, options));
        return;
      }

      if (options && options.outputDataType !== undefined) {
        if (limits.output.dataTypes.includes(options.outputDataType)) {
          const output = builder[operatorName](input, axis, options);
          assert_equals(output.dataType, options.outputDataType);
          assert_array_equals(output.shape, test.output.shape);
        } else {
          assert_throws_js(
              TypeError, () => builder[operatorName](input, axis, options));
        }
        return;
      }

      if (test.output) {
        const output = builder[operatorName](input, axis, options);
        // With no outputDataType requested, the output is expected as int32.
        assert_equals(output.dataType, 'int32');
        assert_array_equals(output.shape, test.output.shape);
      } else {
        assert_throws_with_label(
            () => builder[operatorName](input, axis, options), labelPattern);
      }
    }, test.name.replace('[argMin/Max]', `[${operatorName}]`));
  });
}

kArgMinMaxOperators.forEach((operatorName) => {
  validateInputFromAnotherBuilder(operatorName);
  runTests(operatorName, tests);
});