summaryrefslogtreecommitdiffstats
path: root/dom/webgpu/tests/cts/checkout/src/webgpu/util/compare.ts
diff options
context:
space:
mode:
Diffstat (limited to 'dom/webgpu/tests/cts/checkout/src/webgpu/util/compare.ts')
-rw-r--r--dom/webgpu/tests/cts/checkout/src/webgpu/util/compare.ts282
1 files changed, 282 insertions, 0 deletions
diff --git a/dom/webgpu/tests/cts/checkout/src/webgpu/util/compare.ts b/dom/webgpu/tests/cts/checkout/src/webgpu/util/compare.ts
new file mode 100644
index 0000000000..93fa55303f
--- /dev/null
+++ b/dom/webgpu/tests/cts/checkout/src/webgpu/util/compare.ts
@@ -0,0 +1,282 @@
+import { getIsBuildingDataCache } from '../../common/framework/data_cache.js';
+import { Colors } from '../../common/util/colors.js';
+import {
+ deserializeExpectation,
+ SerializedExpectation,
+ serializeExpectation,
+} from '../shader/execution/expression/case_cache.js';
+import { Expectation, toComparator } from '../shader/execution/expression/expression.js';
+
+import { isFloatValue, Scalar, Value, Vector } from './conversion.js';
+import { F32Interval } from './f32_interval.js';
+
+/** Comparison describes the result of a Comparator function. */
+export interface Comparison {
+ matched: boolean; // True if the two values were considered a match
+ got: string; // The string representation of the 'got' value (possibly with markup)
+ expected: string; // The string representation of the 'expected' value (possibly with markup)
+}
+
+/** Comparator is a function that compares whether the provided value matches an expectation. */
+export interface Comparator {
+ (got: Value): Comparison;
+}
+
+/**
+ * compares 'got' Value to 'expected' Value, returning the Comparison information.
+ * @param got the Value obtained from the test
+ * @param expected the expected Value
+ * @returns the comparison results
+ */
+function compareValue(got: Value, expected: Value): Comparison {
+ {
+ // Check types
+ const gTy = got.type;
+ const eTy = expected.type;
+ const bothFloatTypes = isFloatValue(got) && isFloatValue(expected);
+ if (gTy !== eTy && !bothFloatTypes) {
+ return {
+ matched: false,
+ got: `${Colors.red(gTy.toString())}(${got})`,
+ expected: `${Colors.red(eTy.toString())}(${expected})`,
+ };
+ }
+ }
+
+ if (got instanceof Scalar) {
+ const g = got;
+ const e = expected as Scalar;
+ const isFloat = g.type.kind === 'f64' || g.type.kind === 'f32' || g.type.kind === 'f16';
+ const matched =
+ (isFloat && (g.value as number) === (e.value as number)) || (!isFloat && g.value === e.value);
+ return {
+ matched,
+ got: g.toString(),
+ expected: matched ? Colors.green(e.toString()) : Colors.red(e.toString()),
+ };
+ }
+
+ if (got instanceof Vector) {
+ const gLen = got.elements.length;
+ const eLen = (expected as Vector).elements.length;
+ let matched = gLen === eLen;
+ const gElements = new Array<string>(gLen);
+ const eElements = new Array<string>(eLen);
+ for (let i = 0; i < Math.max(gLen, eLen); i++) {
+ if (i < gLen && i < eLen) {
+ const g = got.elements[i];
+ const e = (expected as Vector).elements[i];
+ const cmp = compare(g, e);
+ matched = matched && cmp.matched;
+ gElements[i] = cmp.got;
+ eElements[i] = cmp.expected;
+ continue;
+ }
+ matched = false;
+ if (i < gLen) {
+ gElements[i] = got.elements[i].toString();
+ }
+ if (i < eLen) {
+ eElements[i] = (expected as Vector).elements[i].toString();
+ }
+ }
+ return {
+ matched,
+ got: `${got.type}(${gElements.join(', ')})`,
+ expected: `${expected.type}(${eElements.join(', ')})`,
+ };
+ }
+ throw new Error(`unhandled type '${typeof got}`);
+}
+
+/**
+ * Tests it a 'got' Value is contained in 'expected' interval, returning the Comparison information.
+ * @param got the Value obtained from the test
+ * @param expected the expected F32Interval
+ * @returns the comparison results
+ */
+function compareInterval(got: Value, expected: F32Interval): Comparison {
+ {
+ // Check type
+ const gTy = got.type;
+ if (!isFloatValue(got)) {
+ return {
+ matched: false,
+ got: `${Colors.red(gTy.toString())}(${got})`,
+ expected: `floating point value`,
+ };
+ }
+ }
+
+ if (got instanceof Scalar) {
+ const g = got.value as number;
+ const matched = expected.contains(g);
+ return {
+ matched,
+ got: g.toString(),
+ expected: matched ? Colors.green(expected.toString()) : Colors.red(expected.toString()),
+ };
+ }
+
+ // Vector results are currently not handled
+ throw new Error(`unhandled type '${typeof got}`);
+}
+
+/**
+ * Tests it a 'got' Value is contained in 'expected' vector, returning the Comparison information.
+ * @param got the Value obtained from the test, is expected to be a Vector
+ * @param expected the expected array of F32Intervals, one for each element of the vector
+ * @returns the comparison results
+ */
+function compareVector(got: Value, expected: F32Interval[]): Comparison {
+ // Check got type
+ if (!(got instanceof Vector)) {
+ return {
+ matched: false,
+ got: `${Colors.red((typeof got).toString())}(${got})`,
+ expected: `Vector`,
+ };
+ }
+
+ // Check element type
+ {
+ const gTy = got.type.elementType;
+ if (!isFloatValue(got.elements[0])) {
+ return {
+ matched: false,
+ got: `${Colors.red(gTy.toString())}(${got})`,
+ expected: `floating point elements`,
+ };
+ }
+ }
+
+ if (got.elements.length !== expected.length) {
+ return {
+ matched: false,
+ got: `Vector of ${got.elements.length} elements`,
+ expected: `${expected.length} elements`,
+ };
+ }
+
+ const results = got.elements.map((_, idx) => {
+ const g = got.elements[idx].value as number;
+ return { match: expected[idx].contains(g), index: idx };
+ });
+
+ const failures = results.filter(v => !v.match).map(v => v.index);
+ if (failures.length !== 0) {
+ const expected_string = expected.map((v, idx) =>
+ idx in failures ? Colors.red(`[${v}]`) : Colors.green(`[${v}]`)
+ );
+ return {
+ matched: false,
+ got: `[${got.elements}]`,
+ expected: `[${expected_string}]`,
+ };
+ }
+
+ return {
+ matched: true,
+ got: `[${got.elements}]`,
+ expected: `[${Colors.green(expected.toString())}]`,
+ };
+}
+
+/**
+ * compare() compares 'got' to 'expected', returning the Comparison information.
+ * @param got the result obtained from the test
+ * @param expected the expected result
+ * @returns the comparison results
+ */
+export function compare(got: Value, expected: Value | F32Interval | F32Interval[]): Comparison {
+ if (expected instanceof Array) {
+ return compareVector(got, expected);
+ }
+
+ if (expected instanceof F32Interval) {
+ return compareInterval(got, expected);
+ }
+
+ return compareValue(got, expected);
+}
+
+/** @returns a Comparator that checks whether a test value matches any of the provided options */
+export function anyOf(
+ ...expectations: Expectation[]
+): Comparator | (Comparator & SerializedComparator) {
+ const comparator = (got: Value) => {
+ const failed = new Set<string>();
+ for (const e of expectations) {
+ const cmp = toComparator(e)(got);
+ if (cmp.matched) {
+ return cmp;
+ }
+ failed.add(cmp.expected);
+ }
+ return { matched: false, got: got.toString(), expected: [...failed].join(' or ') };
+ };
+
+ if (getIsBuildingDataCache()) {
+ // If there's an active DataCache, and it supports storing, then append the
+ // comparator kind and serialized expectations to the comparator, so it can
+ // be serialized.
+ comparator.kind = 'anyOf';
+ comparator.data = expectations.map(e => serializeExpectation(e));
+ }
+ return comparator;
+}
+
+/** @returns a Comparator that skips the test if the expectation is undefined */
+export function skipUndefined(
+ expectation: Expectation | undefined
+): Comparator | (Comparator & SerializedComparator) {
+ const comparator = (got: Value) => {
+ if (expectation !== undefined) {
+ return toComparator(expectation)(got);
+ }
+ return { matched: true, got: got.toString(), expected: `Treating 'undefined' as Any` };
+ };
+
+ if (getIsBuildingDataCache()) {
+ // If there's an active DataCache, and it supports storing, then append the
+ // comparator kind and serialized expectations to the comparator, so it can
+ // be serialized.
+ comparator.kind = 'skipUndefined';
+ if (expectation !== undefined) {
+ comparator.data = serializeExpectation(expectation);
+ }
+ }
+ return comparator;
+}
+
+/** SerializedComparatorAnyOf is the serialized type of an `anyOf` comparator. */
+type SerializedComparatorAnyOf = {
+ kind: 'anyOf';
+ data: SerializedExpectation[];
+};
+
+/** SerializedComparatorSkipUndefined is the serialized type of an `skipUndefined` comparator. */
+type SerializedComparatorSkipUndefined = {
+ kind: 'skipUndefined';
+ data?: SerializedExpectation;
+};
+
+/** SerializedComparator is a union of all the possible serialized comparator types. */
+export type SerializedComparator = SerializedComparatorAnyOf | SerializedComparatorSkipUndefined;
+
+/**
+ * Deserializes a comparator from a SerializedComparator.
+ * @param data the SerializedComparator
+ * @returns the deserialized Comparator.
+ */
+export function deserializeComparator(data: SerializedComparator): Comparator {
+ switch (data.kind) {
+ case 'anyOf': {
+ return anyOf(...data.data.map(e => deserializeExpectation(e)));
+ }
+ case 'skipUndefined': {
+ return skipUndefined(data.data !== undefined ? deserializeExpectation(data.data) : undefined);
+ }
+ }
+ throw `unhandled comparator kind`;
+}