// META: title=validation tests for WebNN API concat operation
// META: global=window
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

// Label attached to the operator options of the failing cases; the error
// message is checked for this label wrapped in square brackets.
const label = `concate_123`;

// Table-driven cases for builder.concat(inputs, axis).
// A case with an `output` entry must build successfully and produce an operand
// with that data type and shape; a case without `output` must throw.
const tests = [
  {
    name: '[concat] Test building Concat with one input.',
    inputs: [{dataType: 'float32', shape: [4, 4, 3]}],
    axis: 2,
    output: {dataType: 'float32', shape: [4, 4, 3]}
  },
  {
    name: '[concat] Test building Concat with two inputs',
    inputs: [
      {dataType: 'float32', shape: [3, 1, 5]},
      {dataType: 'float32', shape: [3, 2, 5]}
    ],
    axis: 1,
    output: {dataType: 'float32', shape: [3, 3, 5]}
  },
  {
    name: '[concat] Test building Concat with three inputs',
    inputs: [
      {dataType: 'float32', shape: [3, 5, 1]},
      {dataType: 'float32', shape: [3, 5, 2]},
      {dataType: 'float32', shape: [3, 5, 3]}
    ],
    axis: 2,
    output: {dataType: 'float32', shape: [3, 5, 6]}
  },
  {
    name: '[concat] Test building Concat with two 1D inputs.',
    inputs: [{dataType: 'float32', shape: [1]}, {dataType: 'float32', shape: [1]}],
    axis: 0,
    output: {dataType: 'float32', shape: [2]}
  },
  {
    // No `inputs` entry: concat is called with an empty array.
    name: '[concat] Throw if the inputs are empty.',
    axis: 0,
  },
  {
    name: '[concat] Throw if the argument types are inconsistent.',
    inputs: [
      {dataType: 'float32', shape: [1, 1]},
      {dataType: 'int32', shape: [1, 1]}
    ],
    axis: 0,
  },
  {
    name: '[concat] Throw if the inputs have different ranks.',
    inputs: [
      {dataType: 'float32', shape: [1, 1]},
      {dataType: 'float32', shape: [1, 1, 1]}
    ],
    axis: 0,
  },
  {
    name:
        '[concat] Throw if the axis is equal to or greater than the size of ranks',
    inputs: [
      {dataType: 'float32', shape: [1, 1]},
      {dataType: 'float32', shape: [1, 1]}
    ],
    axis: 2,
  },
  {
    name: '[concat] Throw if concat with two 0-D scalars.',
    inputs: [{dataType: 'float32', shape: []}, {dataType: 'float32', shape: []}],
    axis: 0,
  },
  {
    name:
        '[concat] Throw if the inputs have other axes with different sizes except on the axis.',
    inputs: [
      {dataType: 'float32', shape: [1, 1, 1]},
      {dataType: 'float32', shape: [1, 2, 3]}
    ],
    axis: 1,
  },
];

// Register one promise_test per table entry.
tests.forEach(
    test => promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      // Turn every descriptor into an input operand named `inputs[<index>]`;
      // a missing `inputs` entry yields an empty array.
      const inputs = (test.inputs ?? [])
                         .map(
                             (descriptor, index) => builder.input(
                                 `inputs[${index}]`, descriptor));

      if (test.output) {
        const output = builder.concat(inputs, test.axis);
        assert_equals(output.dataType, test.output.dataType);
        assert_array_equals(output.shape, test.output.shape);
      } else {
        const options = {label};
        // Matches the label surrounded by literal square brackets.
        const labelPattern = new RegExp(`\\[${label}\\]`);
        assert_throws_with_label(
            () => builder.concat(inputs, test.axis, options), labelPattern);
      }
    }, test.name));

multi_builder_test(async (t, builder, otherBuilder) => {
  const operandDescriptor = {dataType: 'float32', shape: [2, 2]};

  const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);

  const input1 = builder.input('input', operandDescriptor);
  const input2 = builder.input('input', operandDescriptor);
  const input3 = builder.input('input', operandDescriptor);

  // An axis is passed, as in the table-driven cases, so that the TypeError is
  // attributable to the foreign operand rather than to an incomplete call.
  assert_throws_js(
      TypeError,
      () => builder.concat(
          [input1, input2, inputFromOtherBuilder, input3], /*axis=*/ 0));
}, '[concat] throw if any input is from another builder');

promise_test(async t => {
  const builder = new MLGraphBuilder(context);

  // Each float32 input holds maxTensorByteLength / 4 elements, so joining
  // three of them along axis 0 asks for three times the reported limit.
  const operandDescriptor = {
    dataType: 'float32',
    shape: [context.opSupportLimits().maxTensorByteLength / 4]
  };
  const input1 = builder.input('input1', operandDescriptor);
  const input2 = builder.input('input2', operandDescriptor);
  const input3 = builder.input('input3', operandDescriptor);

  // The operands go in as one array plus an axis; passing them as separate
  // positional arguments would fail on the call shape instead of the size.
  assert_throws_js(
      TypeError, () => builder.concat([input1, input2, input3], /*axis=*/ 0));
}, '[concat] throw if the output tensor byte length exceeds limit');