110 lines
3.6 KiB
JavaScript
110 lines
3.6 KiB
JavaScript
// META: title=validation tests for WebNN API softmax operation
|
|
// META: global=window
|
|
// META: variant=?cpu
|
|
// META: variant=?gpu
|
|
// META: variant=?npu
|
|
// META: script=../resources/utils_validation.js
|
|
|
|
'use strict';
|
|
|
|
const tests_without_axis = [
|
|
{
|
|
name: '[softmax] Test building Softmax with float32 input without axis.',
|
|
input: {dataType: 'float32', shape: [4, 3]},
|
|
output: {dataType: 'float32', shape: [4, 3]}
|
|
},
|
|
{
|
|
name: '[softmax] Test building Softmax with float16 input without axis.',
|
|
input: {dataType: 'float16', shape: [3, 5]},
|
|
output: {dataType: 'float16', shape: [3, 5]}
|
|
},
|
|
{
|
|
name: '[softmax] Throw if the input is not a non-floating point data.',
|
|
input: {dataType: 'int32', shape: [3, 2]}
|
|
},
|
|
{
|
|
name: '[softmax] Throw if the input dimensions is not 2.',
|
|
input: {dataType: 'float32', shape: [1, 4, 3]}
|
|
}
|
|
];
|
|
|
|
tests_without_axis.forEach(
|
|
test => promise_test(async t => {
|
|
const builder = new MLGraphBuilder(context);
|
|
let input = builder.input(`input`, test.input);
|
|
if (test.output) {
|
|
const output = builder.softmax(input);
|
|
assert_equals(output.dataType, test.output.dataType);
|
|
assert_array_equals(output.shape, test.output.shape);
|
|
} else {
|
|
const options = {
|
|
label: 'softmax_xxx',
|
|
};
|
|
try {
|
|
builder.softmax(input, options);
|
|
} catch (e) {
|
|
assert_equals(e.name, 'TypeError');
|
|
const error_message = e.message;
|
|
const regrexp = /\[softmax_xxx\]/;
|
|
assert_not_equals(error_message.match(regrexp), null);
|
|
}
|
|
}
|
|
}, test.name));
|
|
|
|
multi_builder_test(async (t, builder, otherBuilder) => {
|
|
const operandDescriptor = {dataType: 'float32', shape: [2, 3]};
|
|
const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);
|
|
|
|
assert_throws_js(TypeError, () => builder.softmax(inputFromOtherBuilder));
|
|
}, '[softmax without axis] throw if any input is from another builder');
|
|
|
|
const tests = [
|
|
{
|
|
name: '[softmax] Test building Softmax with float32 input.',
|
|
input: {dataType: 'float32', shape: [4, 4, 3]},
|
|
axis: 1,
|
|
output: {dataType: 'float32', shape: [4, 4, 3]}
|
|
},
|
|
{
|
|
name: '[softmax] Test building Softmax with float16 input.',
|
|
input: {dataType: 'float16', shape: [3, 1, 5, 2]},
|
|
axis: 2,
|
|
output: {dataType: 'float16', shape: [3, 1, 5, 2]}
|
|
},
|
|
{
|
|
name: '[softmax] Throw if the input is not a non-floating-point data.',
|
|
input: {dataType: 'int32', shape: [3, 1, 5, 2]},
|
|
axis: 3
|
|
},
|
|
{
|
|
name: '[softmax] Throw if the axis is greater than input rank - 1.',
|
|
input: {dataType: 'float16', shape: [3, 1, 5, 2]},
|
|
axis: 4
|
|
}
|
|
];
|
|
|
|
tests.forEach(
|
|
test => promise_test(async t => {
|
|
const builder = new MLGraphBuilder(context);
|
|
let input = builder.input(`input`, test.input);
|
|
if (test.output) {
|
|
const output = builder.softmax(input, test.axis);
|
|
assert_equals(output.dataType, test.output.dataType);
|
|
assert_array_equals(output.shape, test.output.shape);
|
|
} else {
|
|
const label = 'softmax_xxx';
|
|
const options = {label};
|
|
const regrexp = new RegExp('\\[' + label + '\\]');
|
|
assert_throws_with_label(
|
|
() => builder.softmax(input, test.axis, options), regrexp);
|
|
}
|
|
}, test.name));
|
|
|
|
multi_builder_test(async (t, builder, otherBuilder) => {
|
|
const operandDescriptor = {dataType: 'float32', shape: [1, 2, 3]};
|
|
const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);
|
|
const axis = 1;
|
|
|
|
assert_throws_js(
|
|
TypeError, () => builder.softmax(inputFromOtherBuilder, axis));
|
|
}, '[softmax] throw if any input is from another builder');
|