summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/webnn/gemm.https.any.js
blob: 179a1649b6e3077ba0cfd73bbce449d9683d461f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// META: title=test WebNN API gemm operation
// META: global=window,dedicatedworker
// META: script=./resources/utils.js
// META: timeout=long

'use strict';

// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-gemm

/**
 * Builds a single gemm node for the WebNN test harness.
 *
 * Creates the two input operands from `resources`, converts an optional `c`
 * entry of `resources.options` into an operand, and calls
 * `builder[operationName]` with them.
 *
 * @param {string} operationName - Name of the builder method to call
 *     (e.g. 'gemm').
 * @param {MLGraphBuilder} builder - Graph builder used to create operands.
 * @param {Object} resources - Test-case description. Reads `options`
 *     (optional; its `c` may be an operand descriptor with a `shape`, or a
 *     scalar value) and `expected.name`.
 * @returns {Object} Object mapping `resources.expected.name` to the operand
 *     returned by `builder[operationName]`.
 */
const buildGemm = (operationName, builder, resources) => {
  // MLOperand gemm(MLOperand a, MLOperand b, optional MLGemmOptions options = {});
  const [inputOperandA, inputOperandB] = createMultiInputOperands(builder, resources);
  // Shallow copy so that replacing `c` below never mutates the shared test data.
  const gemmOptions = {...resources.options};
  // Compare against null/undefined rather than testing truthiness, so a scalar
  // `c` of 0 is still converted to an operand instead of being passed raw.
  if (gemmOptions.c != null) {
    gemmOptions.c = gemmOptions.c.shape ?
        // Descriptor with a shape: build a constant operand from it.
        createConstantOperand(builder, gemmOptions.c) :
        // Scalar: create a single-value constant operand.
        builder.constant(gemmOptions.c);
  }
  return {
    [resources.expected.name]:
        builder[operationName](inputOperandA, inputOperandB, gemmOptions),
  };
};

// Hand the 'gemm' operation name and the buildGemm graph-building callback to
// the shared test harness. testWebNNOperation is not defined in this file;
// presumably it comes from ./resources/utils.js (loaded via META) — confirm.
testWebNNOperation('gemm', buildGemm);