summaryrefslogtreecommitdiffstats
path: root/dom/webgpu/tests/cts/checkout/src/webgpu/util/check_contents.ts
blob: 325cae41a24b16d9ac01ea7180e50635425dd6c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
// MAINTENANCE_TODO: The "checkThingTrue" naming is confusing; these must be used with `expectOK`
// or the result is dropped on the floor. Rename these to things like `typedArrayIsOK`(??) to
// make it clearer.
// MAINTENANCE_TODO: Also, audit to make sure we aren't dropping any on the floor. Consider a
// no-ignored-return lint check if we can find one that we can use.

import {
  assert,
  ErrorWithExtra,
  iterRange,
  range,
  TypedArrayBufferView,
  TypedArrayBufferViewConstructor,
} from '../../common/util/util.js';

import { float16BitsToFloat32 } from './conversion.js';
import { generatePrettyTable } from './pretty_diff_tables.js';

/**
 * Generate an expected value at `index`, to test for equality with the actual value.
 * Called with the element index; the returned number is compared against the actual element.
 */
export type CheckElementsGenerator = (index: number) => number;
/**
 * Check whether the actual `value` at `index` is as expected.
 * Returns `true` when the element is acceptable and `false` when it should be reported as failed.
 */
export type CheckElementsPredicate = (index: number, value: number) => boolean;
/**
 * Provides a pretty-printing implementation for a particular CheckElementsPredicate.
 * This is an array; each element provides info to print an additional row in the error message
 * (for example, a row showing the expected value for each printed element).
 */
export type CheckElementsSupplementalTableRows = Array<{
  /** Row header, printed in the leftmost column of the row. */
  leftHeader: string;
  /**
   * Get the value for a cell in the table with element index `index`.
   * May be a string or a number; a number will be formatted according to the TypedArray type used.
   */
  getValueForCell: (index: number) => number | string;
}>;

/**
 * Compare two `TypedArray`s element by element.
 *
 * Both arrays must be of the same TypedArray type and the same length; either mismatch trips an
 * assertion rather than producing a returned error.
 *
 * @returns `undefined` when every element matches, otherwise an `Error` describing the mismatch.
 */
export function checkElementsEqual(
  actual: TypedArrayBufferView,
  expected: TypedArrayBufferView
): ErrorWithExtra | undefined {
  assert(actual.constructor === expected.constructor, 'TypedArray type mismatch');
  assert(actual.length === expected.length, 'size mismatch');

  // Delegate to the generator-based check, reading expected values straight out of `expected`.
  const expectedAt: CheckElementsGenerator = index => expected[index];
  return checkElementsEqualGenerated(actual, expectedAt);
}

/**
 * Check whether each value in a `TypedArray` lies between the two corresponding "expected" values
 * (either `a(i) <= actual[i] <= b(i)` or `a(i) >= actual[i] >= b(i)`); the two bounds may be
 * given in either order.
 *
 * @returns `undefined` when every element is within its bounds, otherwise an `Error` whose extras
 * include the `expected` generator pair.
 */
export function checkElementsBetween(
  actual: TypedArrayBufferView,
  expected: readonly [CheckElementsGenerator, CheckElementsGenerator]
): ErrorWithExtra | undefined {
  const [boundA, boundB] = expected;

  const isWithinBounds: CheckElementsPredicate = (index, value) => {
    // Written as a negated `>=` so that a NaN value or bound is rejected.
    const lowest = Math.min(boundA(index), boundB(index));
    if (!(value >= lowest)) {
      return false;
    }
    const highest = Math.max(boundA(index), boundB(index));
    return value <= highest;
  };

  // Extra table rows showing both bounds underneath the actual values.
  const boundRows: CheckElementsSupplementalTableRows = [
    { leftHeader: 'between', getValueForCell: index => boundA(index) },
    { leftHeader: 'and', getValueForCell: index => boundB(index) },
  ];

  const failure = checkElementsPassPredicate(actual, isWithinBounds, {
    predicatePrinter: boundRows,
  });
  if (!failure) {
    return undefined;
  }
  // Wrap the failure so the bounds generators are available as extras as well.
  return new ErrorWithExtra(failure, () => ({ expected }));
}

/**
 * Equivalent to {@link checkElementsBetween}, except that the elements of `actual` and of both
 * `expected` arrays are treated as float16 bit patterns and converted to JS numbers before the
 * comparison is made.
 *
 * Asserts that `actual` has 2-byte elements.
 *
 * @returns `undefined` when every converted element is within its bounds, otherwise an `Error`
 * whose extras include the converted bounds as `expectedF32`.
 */
export function checkElementsFloat16Between(
  actual: TypedArrayBufferView,
  expected: readonly [TypedArrayBufferView, TypedArrayBufferView]
): ErrorWithExtra | undefined {
  assert(actual.BYTES_PER_ELEMENT === 2, 'bytes per element need to be 2 (16bit)');

  /** Convert every float16 bit pattern in `bits` into a Float32Array of the same length. */
  const decodeFloat16 = (bits: TypedArrayBufferView): Float32Array => {
    const decoded = new Float32Array(bits.length);
    for (let i = 0; i < bits.length; ++i) {
      decoded[i] = float16BitsToFloat32(bits[i]);
    }
    return decoded;
  };

  const actualF32 = decodeFloat16(actual);
  const expectedF32 = [decodeFloat16(expected[0]), decodeFloat16(expected[1])];
  const boundA = expectedF32[0];
  const boundB = expectedF32[1];

  const isWithinBounds: CheckElementsPredicate = (index, value) => {
    // Written as a negated `>=` so that a NaN (or missing) bound rejects the element.
    if (!(value >= Math.min(boundA[index], boundB[index]))) {
      return false;
    }
    return value <= Math.max(boundA[index], boundB[index]);
  };

  // Extra table rows showing both converted bounds underneath the actual values.
  const boundRows: CheckElementsSupplementalTableRows = [
    { leftHeader: 'between', getValueForCell: index => boundA[index] },
    { leftHeader: 'and', getValueForCell: index => boundB[index] },
  ];

  const failure = checkElementsPassPredicate(actualF32, isWithinBounds, {
    predicatePrinter: boundRows,
  });
  if (!failure) {
    return undefined;
  }
  // Wrap the failure so the converted bounds are available as extras as well.
  return new ErrorWithExtra(failure, () => ({ expectedF32 }));
}

/**
 * Check whether each value in a `TypedArray` matches at least one of the two corresponding
 * "expected" values (`actual[i] === a[i]` or `actual[i] === b[i]`).
 *
 * @returns `undefined` when every element matches one of its candidates, otherwise an `Error`
 * whose extras include the `expected` array pair.
 */
export function checkElementsEqualEither(
  actual: TypedArrayBufferView,
  expected: readonly [TypedArrayBufferView, TypedArrayBufferView]
): ErrorWithExtra | undefined {
  const [candidateA, candidateB] = expected;

  const matchesEither: CheckElementsPredicate = (index, value) => {
    if (value === candidateA[index]) {
      return true;
    }
    return value === candidateB[index];
  };

  // Extra table rows showing both candidates underneath the actual values.
  const candidateRows: CheckElementsSupplementalTableRows = [
    { leftHeader: 'either', getValueForCell: index => candidateA[index] },
    { leftHeader: 'or', getValueForCell: index => candidateB[index] },
  ];

  const failure = checkElementsPassPredicate(actual, matchesEither, {
    predicatePrinter: candidateRows,
  });
  if (!failure) {
    return undefined;
  }
  // Wrap the failure so both candidate arrays are available as extras as well.
  return new ErrorWithExtra(failure, () => ({ expected }));
}

/**
 * Check a `TypedArray` against values produced on demand by `generator`: element `i` passes when
 * `actual[i] === generator(i)`.
 *
 * @returns `undefined` when every element matches, otherwise an `Error` whose extras include the
 * `generator`. The error message looks like one of the following:
 *
 * ```text
 * Array had unexpected contents at indices 2 through 19.
 *  Starting at index 1:
 *    actual == 0x: 00 fe ff 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 00
 *    failed ->        xx xx    xx xx xx xx xx xx xx xx xx xx xx xx xx xx xx
 *  expected ==     00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
 * ```
 *
 * ```text
 * Array had unexpected contents at indices 2 through 29.
 *  Starting at index 1:
 *    actual ==  0.000 -2.000e+100 -1.000e+100 0.000 1.000e+100 2.000e+100 3.000e+100 4.000e+100 5.000e+100 6.000e+100 7.000e+100 ...
 *    failed ->                 xx          xx               xx         xx         xx         xx         xx         xx         xx ...
 *  expected ==  0.000       0.000       0.000 0.000      0.000      0.000      0.000      0.000      0.000      0.000      0.000 ...
 * ```
 */
export function checkElementsEqualGenerated(
  actual: TypedArrayBufferView,
  generator: CheckElementsGenerator
): ErrorWithExtra | undefined {
  const matchesGenerated: CheckElementsPredicate = (index, value) => value === generator(index);

  // A single extra table row showing the generated value for each printed element.
  const expectedRow: CheckElementsSupplementalTableRows = [
    { leftHeader: 'expected ==', getValueForCell: index => generator(index) },
  ];

  const failure = checkElementsPassPredicate(actual, matchesGenerated, {
    predicatePrinter: expectedRow,
  });
  if (!failure) {
    return undefined;
  }
  // Wrap the failure so the generator is available as an extra as well.
  return new ErrorWithExtra(failure, () => ({ generator }));
}

/**
 * Run `predicate` over every element of a `TypedArray` and report the elements that fail it.
 *
 * On failure, the error message contains a table covering the span from the first to the last
 * failing element (plus one neighbour on each side when available): a row of actual values, a
 * row marking failures with `xx`, and one additional row per entry in `predicatePrinter`.
 * Float32Array/Float64Array values are printed with 4 significant digits; all other array types
 * are printed as zero-padded hex.
 *
 * @returns `undefined` when every element passes, otherwise an `Error` whose extras include a
 * copy of `actual`.
 */
export function checkElementsPassPredicate(
  actual: TypedArrayBufferView,
  predicate: CheckElementsPredicate,
  { predicatePrinter }: { predicatePrinter?: CheckElementsSupplementalTableRows }
): ErrorWithExtra | undefined {
  const elementCount = actual.length;

  // Scan every element, flagging the failures and remembering where they begin.
  /** Sparse array holding `true` at each failing index. */
  const failureFlags: (true | undefined)[] = [];
  let firstFailure = -1;
  for (let index = 0; index < elementCount; ++index) {
    if (predicate(index, actual[index])) {
      continue;
    }
    if (firstFailure < 0) {
      firstFailure = index;
    }
    failureFlags[index] = true;
  }
  if (firstFailure < 0) {
    return undefined;
  }
  // The sparse array's length is one past the highest index that was ever set.
  const lastFailure = failureFlags.length - 1;

  // Printed window: the failing span, widened by one passing element on each side where possible.
  const windowStart = Math.max(0, firstFailure - 1);
  const windowEnd = Math.min(elementCount, lastFailure + 2);
  const windowLength = windowEnd - windowStart;

  // Floats are printed in decimal; everything else as hex padded to the element's byte width.
  const arrayCtor = actual.constructor as TypedArrayBufferViewConstructor;
  const isFloatArray = arrayCtor === Float32Array || arrayCtor === Float64Array;
  const formatNumber = (n: number) =>
    isFloatArray
      ? n.toPrecision(4)
      : intToPaddedHex(n, { byteLength: arrayCtor.BYTES_PER_ELEMENT });
  const actualRowPrefix = isFloatArray ? '' : '0x:';

  /** Lazily produces the row that marks each failing element in the window with `xx`. */
  const failureMarkerRow = function* () {
    yield 'failed ->';
    yield '';
    yield* range(windowLength, i => (failureFlags[windowStart + i] ? 'xx' : ''));
  };

  /** Lazily produces one caller-supplied row: its header, a blank cell, then one cell per element. */
  const supplementalRow = function* (
    header: string,
    cellAt: (index: number) => number | string
  ) {
    yield header;
    yield '';
    yield* iterRange(windowLength, i => cellAt(windowStart + i));
  };

  const supplementalRows: Array<Iterable<string | number>> = [];
  if (predicatePrinter) {
    for (const { leftHeader, getValueForCell } of predicatePrinter) {
      supplementalRows.push(supplementalRow(leftHeader, getValueForCell));
    }
  }

  const table = generatePrettyTable({ fillToWidth: 120, numberToString: formatNumber }, [
    ['actual ==', actualRowPrefix, ...actual.subarray(windowStart, windowEnd)],
    failureMarkerRow(),
    ...supplementalRows,
  ]);

  const message = `Array had unexpected contents at indices ${firstFailure} through ${lastFailure}.
 Starting at index ${windowStart}:
${table}`;

  // The copy of `actual` is only taken if and when the extras are requested.
  return new ErrorWithExtra(message, () => ({
    actual: actual.slice(),
  }));
}

// Helper helpers

/**
 * Render an integer as lowercase hex digits, left-padded with zeros to two digits per byte of
 * `byteLength` (no padding when `byteLength` is 0). Negative values are rendered as a `-` sign
 * followed by the padded magnitude. Asserts that `number` is an integer.
 */
function intToPaddedHex(number: number, { byteLength }: { byteLength: number }) {
  assert(Number.isInteger(number), 'number must be integer');
  const magnitudeHex = Math.abs(number).toString(16);
  const paddedHex = byteLength ? magnitudeHex.padStart(byteLength * 2, '0') : magnitudeHex;
  return number < 0 ? '-' + paddedHex : paddedHex;
}