'use strict'; // https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype const allWebNNOperandDataTypes = [ 'float32', 'float16', 'int32', 'uint32', 'int64', 'uint64', 'int8', 'uint8' ]; // https://webidl.spec.whatwg.org/#idl-unsigned-long // The unsigned long type is an unsigned integer type that has values in the // range [0, 4294967295]. // 4294967295 = 2 ** 32 - 1 const kMaxUnsignedLong = 2 ** 32 - 1; const floatingPointTypes = ['float32', 'float16']; const signedIntegerTypes = ['int32', 'int64', 'int8']; const unsignedLongType = 'unsigned long'; const shape0D = []; const shape1D = [2]; const shape2D = [2, 3]; const shape3D = [2, 3, 4]; const shape4D = [2, 3, 4, 5]; const shape5D = [2, 3, 4, 5, 6]; const adjustOffsetsArray = [ // Decrease 1 -1, // Increase 1 1 ]; // TODO // Add more 5+ dimensions const allWebNNShapesArray = [shape0D, shape1D, shape2D, shape3D, shape4D, shape5D]; const notUnsignedLongAxisArray = [ // String 'abc', // BigInt BigInt(100), // Object { value: 1 }, // Array Object [0, 1], // Date Object new Date("2024-01-01"), ]; function getRank(inputShape) { return inputShape.length; } function getAxisArray(inputShape) { return Array.from({length: inputShape.length}, (_, i) => i); } function getAxesArrayContainSameValues(inputShape) { // TODO // Currently this function returns an array containing each element which all have the same value. // For example axes: [0, 1, 2] for 3D input tensor // this function returns // [ // // two values are same // [0, 0], // [1, 1], // [2, 2], // // three values are same // [0, 0, 0], // [1, 1, 1] // [2, 2, 2] // ] // while it should return // [ // // two values are same // [0, 0], // [1, 1], // [2, 2], // [0, 0, 1], // [0, 0, 2], // [0, 1, 0], // [0, 2, 0], // [1, 0, 0], // [2, 0, 0], // [1, 1, 0], // [1, 1, 2], // [1, 0, 1], // [1, 2, 1], // [0, 1, 1], // [2, 1, 1], // [2, 2, 0], // [2, 2, 1], // [2, 0, 2], // [2, 1, 2], // [0, 2, 2], // [1, 2, 2], // // three (all) values are same // [0, 0, 0], // [1, 1, 1] // [2, 2, 2] // ] const axesArrayContainSameValues = []; const length = inputShape.length; if (length >= 2) { const validAxesArrayFull = getAxisArray(inputShape); for (let index = 0; index < length; index++) { axesArrayContainSameValues.push(new Array(2).fill(validAxesArrayFull[index])); if (length > 2) { axesArrayContainSameValues.push(new Array(3).fill(validAxesArrayFull[index])); } } } return axesArrayContainSameValues; } function generateUnbroadcastableShapes(shape) { // Currently this function returns an array of some unbroadcastable shapes. // for example given the input shape [2, 3, 4] // this function returns // [ // [3, 3, 4], // [2, 2, 4], // [2, 4, 4], // [2, 3, 3], // [2, 3, 5], // [3], // [5], // [1, 3], // [1, 5], // [1, 1, 3], // [1, 1, 5], // [1, 1, 1, 3], // [1, 1, 1, 5], // ] if (shape.every(dimension => dimension === 1)) { throw new Error(`[${shape}] always can be broadcasted`); } const resultShapes = []; const length = shape.length; if (!shape.slice(0, length - 1).every(dimension => dimension === 1)) { for (let i = 0; i < length; i++) { if (shape[i] !== 1) { for (let offset of [-1, 1]) { const shapeB = shape.slice(); shapeB[i] += offset; if (shapeB[i] !== 1) { resultShapes.push(shapeB); } } } } } const lastDimensionSize = shape[length - 1]; if (lastDimensionSize !== 1) { for (let j = 0; j <= length; j++) { if (lastDimensionSize > 2) { resultShapes.push(Array(j).fill(1).concat([lastDimensionSize - 1])); } resultShapes.push(Array(j).fill(1).concat([lastDimensionSize + 1])); } } return resultShapes; } function generateOutOfRangeValuesArray(type) { let range, outsideValueArray; switch (type) { case 'unsigned long': range = [0, kMaxUnsignedLong]; break; default: throw new Error(`Unsupport ${type}`); } outsideValueArray = [range[0] - 1, range[1] + 1]; return outsideValueArray; } let inputIndex = 0; let inputAIndex = 0; let inputBIndex = 0; let context; test(() => assert_not_equals(navigator.ml, undefined, "ml property is defined on navigator")); promise_setup(async () => { if (navigator.ml === undefined) { return; } const deviceType = new URLSearchParams(location.search).get('device') || location.search.substring(1); context = await navigator.ml.createContext({deviceType: deviceType}); }, {explicit_timeout: true}); function assert_throws_with_label(func, regrexp) { try { func.call(this); assert_unreached('Graph builder method unexpectedly succeeded'); } catch (e) { assert_equals(e.name, 'TypeError'); const error_message = e.message; assert_not_equals(error_message.match(regrexp), null); } } function validateTwoInputsBroadcastable(operationName, label) { if (navigator.ml === undefined) { return; } promise_test(async t => { const builder = new MLGraphBuilder(context); for (let dataType of allWebNNOperandDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { assert_throws_js( TypeError, () => builder.input( `inputA${++inputAIndex}`, {dataType, shape: shape1D})); continue; } for (let shape of allWebNNShapesArray) { if (shape.length > 0) { const inputA = builder.input(`inputA${++inputAIndex}`, {dataType, shape}); const unbroadcastableShapes = generateUnbroadcastableShapes(shape); for (let shape of unbroadcastableShapes) { const inputB = builder.input(`inputB${++inputBIndex}`, {dataType, shape}); assert_equals(typeof builder[operationName], 'function'); const options = {label}; const regrexp = new RegExp('\\[' + label + '\\]'); assert_throws_with_label( () => builder[operationName](inputA, inputB, options), regrexp); assert_throws_with_label( () => builder[operationName](inputB, inputA, options), regrexp); } } } } }, `[${operationName}] TypeError is expected if two inputs aren't broadcastable`); } function validateTwoBroadcastableInputsTensorLimit(operationName, label) { if (navigator.ml === undefined) { return; } promise_test(async t => { const builder = new MLGraphBuilder(context); const a = builder.input('a', {dataType: 'float32', shape: [context.opSupportLimits().maxTensorByteLength / 4, 1]}); const b = builder.input('b', {dataType: 'float32', shape: [1, 5] }); const options = {label}; const regrexp = new RegExp('\\[' + label + '\\]'); assert_throws_with_label( () => builder[operationName](a, b, options), regrexp); }, `[${operationName}] throw if the output tensor byte length exceeds limit`); } function validateTwoInputsOfSameDataType(operationName, label) { if (navigator.ml === undefined) { return; } let operationNameArray; if (typeof operationName === 'string') { operationNameArray = [operationName]; } else if (Array.isArray(operationName)) { operationNameArray = operationName; } else { throw new Error(`${operationName} should be an operation name string or an operation name string array`); } for (let subOperationName of operationNameArray) { promise_test(async t => { const builder = new MLGraphBuilder(context); for (let dataType of allWebNNOperandDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { assert_throws_js( TypeError, () => builder.input( `inputA${++inputAIndex}`, {dataType, shape: shape1D})); continue; } for (let shape of allWebNNShapesArray) { const inputA = builder.input(`inputA${++inputAIndex}`, {dataType, shape}); for (let dataTypeB of allWebNNOperandDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes( dataTypeB)) { assert_throws_js( TypeError, () => builder.input( `inputB${++inputBIndex}`, {dataTypeB, shape1D})); continue; } if (dataType !== dataTypeB) { const inputB = builder.input( `inputB${++inputBIndex}`, {dataType: dataTypeB, shape}); const options = {label}; const regrexp = new RegExp('\\[' + label + '\\]'); assert_equals(typeof builder[subOperationName], 'function'); assert_throws_with_label( () => builder[subOperationName](inputA, inputB, options), regrexp); } } } } }, `[${subOperationName}] TypeError is expected if two inputs aren't of same data type`); } } /** * Validate options.axes by given operation and input rank for * argMin/Max / layerNormalization / Reduction operations operations * @param {(String[]|String)} operationName - An operation name array or an * operation name */ function validateOptionsAxes(operationName) { if (navigator.ml === undefined) { return; } let operationNameArray; if (typeof operationName === 'string') { operationNameArray = [operationName]; } else if (Array.isArray(operationName)) { operationNameArray = operationName; } else { throw new Error(`${operationName} should be an operation name string or an operation name string array`); } const invalidAxisArray = generateOutOfRangeValuesArray(unsignedLongType); for (let subOperationName of operationNameArray) { // TypeError is expected if any of options.axes elements is not an unsigned long interger promise_test(async t => { const builder = new MLGraphBuilder(context); for (let dataType of allWebNNOperandDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { assert_throws_js( TypeError, () => builder.input( `inputA${++inputAIndex}`, {dataType, shape: shape1D})); continue; } for (let shape of allWebNNShapesArray) { const rank = getRank(shape); if (rank >= 1) { const input = builder.input(`input${++inputIndex}`, {dataType, shape}); for (let invalidAxis of invalidAxisArray) { assert_equals(typeof builder[subOperationName], 'function'); assert_throws_js( TypeError, () => builder[subOperationName](input, {axes: invalidAxis})); } for (let axis of notUnsignedLongAxisArray) { assert_false( typeof axis === 'number' && Number.isInteger(axis), `[${subOperationName}] any of options.axes elements should be of 'unsigned long'`); assert_equals(typeof builder[subOperationName], 'function'); assert_throws_js( TypeError, () => builder[subOperationName](input, {axes: [axis]})); } } } } }, `[${subOperationName}] TypeError is expected if any of options.axes elements is not an unsigned long interger`); // TypeError is expected if any of options.axes elements is greater or equal // to the size of input promise_test(async t => { const builder = new MLGraphBuilder(context); for (let dataType of allWebNNOperandDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { assert_throws_js( TypeError, () => builder.input( `inputA${++inputAIndex}`, {dataType, shape: shape1D})); continue; } for (let shape of allWebNNShapesArray) { const rank = getRank(shape); if (rank >= 1) { const input = builder.input(`input${++inputIndex}`, {dataType, shape}); assert_equals(typeof builder[subOperationName], 'function'); assert_throws_js( TypeError, () => builder[subOperationName](input, {axes: [rank]})); assert_throws_js( TypeError, () => builder[subOperationName](input, {axes: [rank + 1]})); } } } }, `[${subOperationName}] TypeError is expected if any of options.axes elements is greater or equal to the size of input`); // TypeError is expected if two or more values are same in the axes sequence promise_test(async t => { const builder = new MLGraphBuilder(context); for (let dataType of allWebNNOperandDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { assert_throws_js( TypeError, () => builder.input( `inputA${++inputAIndex}`, {dataType, shape: shape1D})); continue; } for (let shape of allWebNNShapesArray) { const rank = getRank(shape); if (rank >= 2) { const input = builder.input(`input${++inputIndex}`, {dataType, shape}); const axesArrayContainSameValues = getAxesArrayContainSameValues(shape); for (let axes of axesArrayContainSameValues) { assert_equals(typeof builder[subOperationName], 'function'); assert_throws_js( TypeError, () => builder[subOperationName](input, {axes})); } } } } }, `[${subOperationName}] TypeError is expected if two or more values are same in the axes sequence`); } } // TODO: remove this method once all the data type limits of the unary // operations are specified in context.OpSupportLimits(). /** * Validate a unary operation * @param {String} operationName - An operation name * @param {Array} supportedDataTypes - Test building with these data types * succeeds and test building with all other data types fails */ function validateUnaryOperation(operationName, supportedDataTypes, label) { promise_test(async t => { const builder = new MLGraphBuilder(context); for (let dataType of supportedDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { assert_throws_js( TypeError, () => builder.input( `inputA${++inputAIndex}`, {dataType, shape: shape1D})); continue; } for (let shape of allWebNNShapesArray) { const input = builder.input(`input`, {dataType, shape}); assert_equals(typeof builder[operationName], 'function'); const output = builder[operationName](input); assert_equals(output.dataType, dataType); assert_array_equals(output.shape, shape); } } }, `[${operationName}] Test building an unary operator with supported type.`); const unsupportedDataTypes = new Set(allWebNNOperandDataTypes).difference(new Set(supportedDataTypes)); promise_test(async t => { const builder = new MLGraphBuilder(context); for (let dataType of unsupportedDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { assert_throws_js( TypeError, () => builder.input( `inputA${++inputAIndex}`, {dataType, shape: shape1D})); continue; } for (let shape of allWebNNShapesArray) { const input = builder.input(`input`, {dataType, shape}); assert_equals(typeof builder[operationName], 'function'); const options = {label}; const regrexp = new RegExp('\\[' + label + '\\]'); assert_throws_with_label( () => builder[operationName](input, options), regrexp); } } }, `[${operationName}] Throw if the dataType is not supported for an unary operator.`); } /** * Validate a single input operation * @param {String} operationName - An operation name */ function validateSingleInputOperation(operationName, label) { promise_test(async t => { const builder = new MLGraphBuilder(context); const supportedDataTypes = context.opSupportLimits()[operationName].input.dataTypes; for (let dataType of supportedDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { continue; } for (let shape of allWebNNShapesArray) { const input = builder.input(`input`, {dataType, shape}); const output = builder[operationName](input); assert_equals(output.dataType, dataType); assert_array_equals(output.shape, shape); } } }, `[${operationName}] Test building the operator with supported data type.`); promise_test(async t => { const builder = new MLGraphBuilder(context); const unsupportedDataTypes = new Set(allWebNNOperandDataTypes) .difference(new Set( context.opSupportLimits()[operationName].input.dataTypes)); for (let dataType of unsupportedDataTypes) { if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { assert_throws_js( TypeError, () => builder.input( `inputA${++inputAIndex}`, {dataType, shape: shape1D})); continue; } for (let shape of allWebNNShapesArray) { const input = builder.input(`input`, {dataType, shape}); assert_equals(typeof builder[operationName], 'function'); const options = {label}; const regrexp = new RegExp('\\[' + label + '\\]'); assert_throws_with_label( () => builder[operationName](input, options), regrexp); } } }, `[${operationName}] Throw if the data type is not supported for the operator.`); } /** * Basic test that the builder method specified by `operationName` throws if * given an input from another builder. Operands which do not accept a float32 * square 2D input should pass their own `operatorDescriptor`. * @param {String} operationName * @param {String} operatorDescriptor */ function validateInputFromAnotherBuilder(operatorName, operatorDescriptor = { dataType: 'float32', shape: [2, 2] }) { multi_builder_test(async (t, builder, otherBuilder) => { const inputFromOtherBuilder = otherBuilder.input('input', operatorDescriptor); assert_equals(typeof builder[operatorName], 'function'); assert_throws_js( TypeError, () => builder[operatorName](inputFromOtherBuilder)); }, `[${operatorName}] throw if input is from another builder`); }; /** * Basic test that the builder method specified by `operationName` throws if one * of its inputs is from another builder. This helper may only be used by * operands which accept float32 square 2D inputs. * @param {String} operationName */ function validateTwoInputsFromMultipleBuilders(operatorName) { const opDescriptor = {dataType: 'float32', shape: [2, 2]}; multi_builder_test(async (t, builder, otherBuilder) => { const inputFromOtherBuilder = otherBuilder.input('other', opDescriptor); const input = builder.input('input', opDescriptor); assert_equals(typeof builder[operatorName], 'function'); assert_throws_js( TypeError, () => builder[operatorName](inputFromOtherBuilder, input)); }, `[${operatorName}] throw if first input is from another builder`); multi_builder_test(async (t, builder, otherBuilder) => { const inputFromOtherBuilder = otherBuilder.input('other', opDescriptor); const input = builder.input('input', opDescriptor); assert_equals(typeof builder[operatorName], 'function'); assert_throws_js( TypeError, () => builder[operatorName](input, inputFromOtherBuilder)); }, `[${operatorName}] throw if second input is from another builder`); }; function multi_builder_test(func, description) { promise_test(async t => { const builder = new MLGraphBuilder(context); const otherBuilder = new MLGraphBuilder(context); await func(t, builder, otherBuilder); }, description); }