summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/webnn/validation_tests/concat.https.any.js
blob: b61f2d2bc779bf622e26da9658a65030795ae667 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
// META: title=validation tests for WebNN API concat operation
// META: global=window,dedicatedworker
// META: script=../resources/utils_validation.js

'use strict';

// Table-driven concat() validation cases.
//
// Every entry names the case and gives the `axis` to concatenate along, plus
// an optional `inputs` list of operand descriptors to create. Entries that
// carry an `output` descriptor must build successfully and infer exactly that
// data type and shape; entries without `output` must make concat() throw a
// TypeError.
const tests = [
  // --- Cases expected to build: the output size along `axis` is the sum of
  // the input sizes along `axis`; all other dimensions are unchanged.
  {
    name: '[concat] Test building Concat with one input.',
    inputs: [
      {dataType: 'float32', dimensions: [4, 4, 3]},
    ],
    axis: 2,
    output: {dataType: 'float32', dimensions: [4, 4, 3]},
  },
  {
    name: '[concat] Test building Concat with two inputs',
    inputs: [
      {dataType: 'float32', dimensions: [3, 1, 5]},
      {dataType: 'float32', dimensions: [3, 2, 5]},
    ],
    axis: 1,
    output: {dataType: 'float32', dimensions: [3, 3, 5]},
  },
  {
    name: '[concat] Test building Concat with three inputs',
    inputs: [
      {dataType: 'float32', dimensions: [3, 5, 1]},
      {dataType: 'float32', dimensions: [3, 5, 2]},
      {dataType: 'float32', dimensions: [3, 5, 3]},
    ],
    axis: 2,
    output: {dataType: 'float32', dimensions: [3, 5, 6]},
  },
  {
    name: '[concat] Test building Concat with two 1D inputs.',
    inputs: [
      {dataType: 'float32', dimensions: [1]},
      {dataType: 'float32', dimensions: [1]},
    ],
    axis: 0,
    output: {dataType: 'float32', dimensions: [2]},
  },

  // --- Cases expected to throw (no `output` field).
  {
    // Deliberately has no `inputs` field: concat() receives an empty list.
    name: '[concat] Throw if the inputs are empty.',
    axis: 0,
  },
  {
    name: '[concat] Throw if the argument types are inconsistent.',
    inputs: [
      {dataType: 'float32', dimensions: [1, 1]},
      {dataType: 'int32', dimensions: [1, 1]},
    ],
    axis: 0,
  },
  {
    name: '[concat] Throw if the inputs have different ranks.',
    inputs: [
      {dataType: 'float32', dimensions: [1, 1]},
      {dataType: 'float32', dimensions: [1, 1, 1]},
    ],
    axis: 0,
  },
  {
    // Rank is 2, so the largest valid axis is 1.
    name: '[concat] Throw if the axis is equal to or greater than the size of ranks',
    inputs: [
      {dataType: 'float32', dimensions: [1, 1]},
      {dataType: 'float32', dimensions: [1, 1]},
    ],
    axis: 2,
  },
  {
    // Rank-0 operands have no axis at all, so even axis 0 is out of range.
    name: '[concat] Throw if concat with two 0-D scalars.',
    inputs: [
      {dataType: 'float32', dimensions: []},
      {dataType: 'float32', dimensions: []},
    ],
    axis: 0,
  },
  {
    // Dimension 2 differs (1 vs 3) even though it is not the concat axis.
    name: '[concat] Throw if the inputs have other axes with different sizes except on the axis.',
    inputs: [
      {dataType: 'float32', dimensions: [1, 1, 1]},
      {dataType: 'float32', dimensions: [1, 2, 3]},
    ],
    axis: 1,
  },
];

// Register one promise_test per table entry. Each test creates the described
// input operands on the shared `builder` (named `inputs[0]`, `inputs[1]`, ...)
// and then either checks the data type and shape that concat() infers, or,
// when the entry has no `output`, expects concat() to throw a TypeError.
tests.forEach((testCase) => {
  promise_test(async (t) => {
    // An entry with no `inputs` field calls concat() with an empty list.
    const inputs = (testCase.inputs || []).map(
        ({dataType, dimensions}, index) =>
            builder.input(`inputs[${index}]`, {dataType, dimensions}));

    if (!testCase.output) {
      assert_throws_js(TypeError, () => builder.concat(inputs, testCase.axis));
      return;
    }

    const output = builder.concat(inputs, testCase.axis);
    assert_equals(output.dataType(), testCase.output.dataType);
    assert_array_equals(output.shape(), testCase.output.dimensions);
  }, testCase.name);
});

// concat() must reject a list that mixes operands from two different
// MLGraphBuilder instances, even when the shapes and data types would
// otherwise concatenate cleanly.
multi_builder_test(async (t, builder, otherBuilder) => {
  const operandDescriptor = {dataType: 'float32', dimensions: [2, 2]};

  const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);

  // Give each same-builder input its own name so that the foreign operand is
  // the only thing wrong with the concat() call below.
  const input1 = builder.input('input1', operandDescriptor);
  const input2 = builder.input('input2', operandDescriptor);
  const input3 = builder.input('input3', operandDescriptor);

  // `axis` is a required argument of concat(). Omitting it makes the call
  // throw TypeError during argument conversion, before the cross-builder
  // check runs, so the test would pass even without that check. Axis 0 is
  // valid for these [2, 2] inputs, so any TypeError here must come from the
  // operand that belongs to `otherBuilder`.
  const axis = 0;
  assert_throws_js(
      TypeError,
      () => builder.concat(
          [input1, input2, inputFromOtherBuilder, input3], axis));
}, '[concat] throw if any input is from another builder');