summaryrefslogtreecommitdiffstats
path: root/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts
diff options
context:
space:
mode:
Diffstat (limited to 'dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts')
-rw-r--r--dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts800
1 files changed, 351 insertions, 449 deletions
diff --git a/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts b/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts
index f85516f29b..be8f1fd7fd 100644
--- a/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts
+++ b/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts
@@ -1,70 +1,25 @@
import { globalTestConfig } from '../../../../common/framework/test_config.js';
-import { ROArrayArray } from '../../../../common/util/types.js';
import { assert, objectEquals, unreachable } from '../../../../common/util/util.js';
import { GPUTest } from '../../../gpu_test.js';
-import { compare, Comparator, ComparatorImpl } from '../../../util/compare.js';
+import { Comparator, ComparatorImpl } from '../../../util/compare.js';
import { kValue } from '../../../util/constants.js';
import {
+ MatrixType,
+ ScalarValue,
ScalarType,
- Scalar,
Type,
- TypeVec,
- TypeU32,
- Value,
- Vector,
VectorType,
- u32,
- i32,
- Matrix,
- MatrixType,
- ScalarBuilder,
+ Value,
+ VectorValue,
+ isAbstractType,
scalarTypeOf,
+ ArrayType,
+ elementTypeOf,
} from '../../../util/conversion.js';
-import { FPInterval } from '../../../util/floating_point.js';
-import {
- cartesianProduct,
- QuantizeFunc,
- quantizeToI32,
- quantizeToU32,
-} from '../../../util/math.js';
-
-export type Expectation =
- | Value
- | FPInterval
- | readonly FPInterval[]
- | ROArrayArray<FPInterval>
- | Comparator;
-
-/** @returns if this Expectation actually a Comparator */
-export function isComparator(e: Expectation): e is Comparator {
- return !(
- e instanceof FPInterval ||
- e instanceof Scalar ||
- e instanceof Vector ||
- e instanceof Matrix ||
- e instanceof Array
- );
-}
-
-/** @returns the input if it is already a Comparator, otherwise wraps it in a 'value' comparator */
-export function toComparator(input: Expectation): Comparator {
- if (isComparator(input)) {
- return input;
- }
-
- return { compare: got => compare(got, input as Value), kind: 'value' };
-}
-
-/** Case is a single expression test case. */
-export type Case = {
- // The input value(s)
- input: Value | ReadonlyArray<Value>;
- // The expected result, or function to check the result
- expected: Expectation;
-};
+import { align } from '../../../util/math.js';
-/** CaseList is a list of Cases */
-export type CaseList = Array<Case>;
+import { Case } from './case.js';
+import { toComparator } from './expectation.js';
/** The input value source */
export type InputSource =
@@ -79,6 +34,9 @@ export const allInputSources: InputSource[] = ['const', 'uniform', 'storage_r',
/** Just constant input source */
export const onlyConstInputSource: InputSource[] = ['const'];
+/** All input sources except const */
+export const allButConstInputSource: InputSource[] = ['uniform', 'storage_r', 'storage_rw'];
+
/** Configuration for running a expression test */
export type Config = {
// Where the input values are read from
@@ -92,127 +50,157 @@ export type Config = {
vectorize?: number;
};
-// Helper for returning the stride for a given Type
-function valueStride(ty: Type): number {
- // AbstractFloats are passed out of the shader via a struct of 2x u32s and
- // unpacking containers as arrays
- if (scalarTypeOf(ty).kind === 'abstract-float') {
- if (ty instanceof ScalarType) {
- return 16;
- }
- if (ty instanceof VectorType) {
- if (ty.width === 2) {
- return 16;
- }
- // vec3s have padding to make them the same size as vec4s
- return 32;
- }
- if (ty instanceof MatrixType) {
- switch (ty.cols) {
- case 2:
- switch (ty.rows) {
- case 2:
- return 32;
- case 3:
- return 64;
- case 4:
- return 64;
- }
- break;
- case 3:
- switch (ty.rows) {
- case 2:
- return 48;
- case 3:
- return 96;
- case 4:
- return 96;
- }
- break;
- case 4:
- switch (ty.rows) {
- case 2:
- return 64;
- case 3:
- return 128;
- case 4:
- return 128;
- }
- break;
- }
+/**
+ * @returns the size and alignment in bytes of the type 'ty', taking into
+ * consideration storage alignment constraints and abstract numerics, which are
+ * encoded as a struct of holding two u32s.
+ */
+function sizeAndAlignmentOf(ty: Type, source: InputSource): { size: number; alignment: number } {
+ if (ty instanceof ScalarType) {
+ if (ty.kind === 'abstract-float' || ty.kind === 'abstract-int') {
+ // AbstractFloats and AbstractInts are passed out of the shader via structs of
+ // 2x u32s and unpacking containers as arrays
+ return { size: 8, alignment: 8 };
}
- unreachable(`AbstractFloats have not yet been implemented for ${ty.toString()}`);
+ return { size: ty.size, alignment: ty.alignment };
+ }
+
+ if (ty instanceof VectorType) {
+ const out = sizeAndAlignmentOf(ty.elementType, source);
+ const n = ty.width === 3 ? 4 : ty.width;
+ out.size *= n;
+ out.alignment *= n;
+ return out;
}
if (ty instanceof MatrixType) {
- switch (ty.cols) {
- case 2:
- switch (ty.rows) {
- case 2:
- return 16;
- case 3:
- return 32;
- case 4:
- return 32;
- }
- break;
- case 3:
- switch (ty.rows) {
- case 2:
- return 32;
- case 3:
- return 64;
- case 4:
- return 64;
- }
- break;
- case 4:
- switch (ty.rows) {
- case 2:
- return 32;
- case 3:
- return 64;
- case 4:
- return 64;
- }
- break;
+ const out = sizeAndAlignmentOf(ty.elementType, source);
+ const n = ty.rows === 3 ? 4 : ty.rows;
+ out.size *= n * ty.cols;
+ out.alignment *= n;
+ return out;
+ }
+
+ if (ty instanceof ArrayType) {
+ const out = sizeAndAlignmentOf(ty.elementType, source);
+ if (source === 'uniform') {
+ out.alignment = align(out.alignment, 16);
}
- unreachable(
- `Attempted to get stride length for a matrix with dimensions (${ty.cols}x${ty.rows}), which isn't currently handled`
- );
+ out.size *= ty.count;
+ return out;
+ }
+
+ unreachable(`unhandled type: ${ty}`);
+}
+
+/**
+ * @returns the stride in bytes of the type 'ty', taking into consideration abstract numerics,
+ * which are encoded as a struct of 2 x u32.
+ */
+function strideOf(ty: Type, source: InputSource): number {
+ const sizeAndAlign = sizeAndAlignmentOf(ty, source);
+ return align(sizeAndAlign.size, sizeAndAlign.alignment);
+}
+
+/**
+ * Calls 'callback' with the layout information of each structure member with the types 'members'.
+ * @returns the byte size, stride and alignment of the structure.
+ */
+export function structLayout(
+ members: Type[],
+ source: InputSource,
+ callback?: (m: {
+ index: number;
+ type: Type;
+ size: number;
+ alignment: number;
+ offset: number;
+ }) => void
+): { size: number; stride: number; alignment: number } {
+ let offset = 0;
+ let alignment = 1;
+ for (let i = 0; i < members.length; i++) {
+ const member = members[i];
+ const sizeAndAlign = sizeAndAlignmentOf(member, source);
+ offset = align(offset, sizeAndAlign.alignment);
+ if (callback) {
+ callback({
+ index: i,
+ type: member,
+ size: sizeAndAlign.size,
+ alignment: sizeAndAlign.alignment,
+ offset,
+ });
+ }
+ offset += sizeAndAlign.size;
+ alignment = Math.max(alignment, sizeAndAlign.alignment);
}
- // Handles scalars and vectors
- return 16;
+ if (source === 'uniform') {
+ alignment = align(alignment, 16);
+ }
+
+ const size = offset;
+ const stride = align(size, alignment);
+ return { size, stride, alignment };
+}
+
+/** @returns the stride in bytes between two consecutive structures with the given members */
+export function structStride(members: Type[], source: InputSource): number {
+ return structLayout(members, source).stride;
}
-// Helper for summing up all of the stride values for an array of Types
-function valueStrides(tys: Type[]): number {
- return tys.map(valueStride).reduce((sum, c) => sum + c);
+/** @returns the WGSL to describe the structure members in 'members' */
+function wgslMembers(members: Type[], source: InputSource, memberName: (i: number) => string) {
+ const lines: string[] = [];
+ const layout = structLayout(members, source, m => {
+ lines.push(` @size(${m.size}) ${memberName(lines.length)} : ${m.type},`);
+ });
+ const padding = layout.stride - layout.size;
+ if (padding > 0) {
+ // Pad with a 'f16' if the padding requires an odd multiple of 2 bytes.
+ // This is required as 'i32' has an alignment and size of 4 bytes.
+ const ty = (padding & 2) !== 0 ? 'f16' : 'i32';
+ lines.push(` @size(${padding}) padding : ${ty},`);
+ }
+ return lines.join('\n');
}
// Helper for returning the WGSL storage type for the given Type.
function storageType(ty: Type): Type {
if (ty instanceof ScalarType) {
assert(ty.kind !== 'f64', `No storage type defined for 'f64' values`);
+ assert(ty.kind !== 'abstract-int', `Custom handling is implemented for 'abstract-int' values`);
assert(
ty.kind !== 'abstract-float',
`Custom handling is implemented for 'abstract-float' values`
);
if (ty.kind === 'bool') {
- return TypeU32;
+ return Type.u32;
}
}
if (ty instanceof VectorType) {
- return TypeVec(ty.width, storageType(ty.elementType) as ScalarType);
+ return Type.vec(ty.width, storageType(ty.elementType) as ScalarType);
+ }
+ if (ty instanceof ArrayType) {
+ return Type.array(ty.count, storageType(ty.elementType));
}
return ty;
}
+/** Structure used to hold [from|to]Storage conversion helpers */
+type TypeConversionHelpers = {
+ // The module-scope WGSL to emit with the shader.
+ wgsl: string;
+ // A function that generates a unique WGSL identifier.
+ uniqueID: () => string;
+};
+
// Helper for converting a value of the type 'ty' from the storage type.
-function fromStorage(ty: Type, expr: string): string {
+function fromStorage(ty: Type, expr: string, helpers: TypeConversionHelpers): string {
if (ty instanceof ScalarType) {
- assert(ty.kind !== 'abstract-float', `AbstractFloat values should not be in input storage`);
+ assert(ty.kind !== 'abstract-int', `'abstract-int' values should not be in input storage`);
+ assert(ty.kind !== 'abstract-float', `'abstract-float' values should not be in input storage`);
assert(ty.kind !== 'f64', `'No storage type defined for 'f64' values`);
if (ty.kind === 'bool') {
return `${expr} != 0u`;
@@ -220,23 +208,46 @@ function fromStorage(ty: Type, expr: string): string {
}
if (ty instanceof VectorType) {
assert(
+ ty.elementType.kind !== 'abstract-int',
+ `'abstract-int' values cannot appear in input storage`
+ );
+ assert(
ty.elementType.kind !== 'abstract-float',
- `AbstractFloat values cannot appear in input storage`
+ `'abstract-float' values cannot appear in input storage`
);
assert(ty.elementType.kind !== 'f64', `'No storage type defined for 'f64' values`);
if (ty.elementType.kind === 'bool') {
- return `${expr} != vec${ty.width}<u32>(0u)`;
+ return `(${expr} != vec${ty.width}<u32>(0u))`;
}
}
+ if (ty instanceof ArrayType && elementTypeOf(ty) === Type.bool) {
+ // array<u32, N> -> array<bool, N>
+ const conv = helpers.uniqueID();
+ const inTy = Type.array(ty.count, Type.u32);
+ helpers.wgsl += `
+fn ${conv}(in : ${inTy}) -> ${ty} {
+ var out : ${ty};
+ for (var i = 0; i < ${ty.count}; i++) {
+ out[i] = in[i] != 0;
+ }
+ return out;
+}
+`;
+ return `${conv}(${expr})`;
+ }
return expr;
}
// Helper for converting a value of the type 'ty' to the storage type.
-function toStorage(ty: Type, expr: string): string {
+function toStorage(ty: Type, expr: string, helpers: TypeConversionHelpers): string {
if (ty instanceof ScalarType) {
assert(
+ ty.kind !== 'abstract-int',
+ `'abstract-int' values have custom code for writing to storage`
+ );
+ assert(
ty.kind !== 'abstract-float',
- `AbstractFloat values have custom code for writing to storage`
+ `'abstract-float' values have custom code for writing to storage`
);
assert(ty.kind !== 'f64', `No storage type defined for 'f64' values`);
if (ty.kind === 'bool') {
@@ -245,14 +256,33 @@ function toStorage(ty: Type, expr: string): string {
}
if (ty instanceof VectorType) {
assert(
+ ty.elementType.kind !== 'abstract-int',
+ `'abstract-int' values have custom code for writing to storage`
+ );
+ assert(
ty.elementType.kind !== 'abstract-float',
- `AbstractFloat values have custom code for writing to storage`
+ `'abstract-float' values have custom code for writing to storage`
);
assert(ty.elementType.kind !== 'f64', `'No storage type defined for 'f64' values`);
if (ty.elementType.kind === 'bool') {
return `select(vec${ty.width}<u32>(0u), vec${ty.width}<u32>(1u), ${expr})`;
}
}
+ if (ty instanceof ArrayType && elementTypeOf(ty) === Type.bool) {
+ // array<bool, N> -> array<u32, N>
+ const conv = helpers.uniqueID();
+ const outTy = Type.array(ty.count, Type.u32);
+ helpers.wgsl += `
+fn ${conv}(in : ${ty}) -> ${outTy} {
+ var out : ${outTy};
+ for (var i = 0; i < ${ty.count}; i++) {
+ out[i] = select(0u, 1u, in[i]);
+ }
+ return out;
+}
+`;
+ return `${conv}(${expr})`;
+ }
return expr;
}
@@ -296,7 +326,7 @@ export async function run(
parameterTypes: Array<Type>,
resultType: Type,
cfg: Config = { inputSource: 'storage_r' },
- cases: CaseList,
+ cases: Case[],
batch_size?: number
) {
// If the 'vectorize' config option was provided, pack the cases into vectors.
@@ -325,12 +355,13 @@ export async function run(
// 2k appears to be a sweet-spot when benchmarking.
return Math.floor(
Math.min(1024 * 2, t.device.limits.maxUniformBufferBindingSize) /
- valueStrides(parameterTypes)
+ structStride(parameterTypes, cfg.inputSource)
);
case 'storage_r':
case 'storage_rw':
return Math.floor(
- t.device.limits.maxStorageBufferBindingSize / valueStrides(parameterTypes)
+ t.device.limits.maxStorageBufferBindingSize /
+ structStride(parameterTypes, cfg.inputSource)
);
}
})();
@@ -353,7 +384,7 @@ export async function run(
}
};
- const processBatch = async (batchCases: CaseList) => {
+ const processBatch = async (batchCases: Case[]) => {
const checkBatch = await submitBatch(
t,
shaderBuilder,
@@ -404,12 +435,13 @@ async function submitBatch(
shaderBuilder: ShaderBuilder,
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
inputSource: InputSource,
pipelineCache: PipelineCache
): Promise<() => void> {
// Construct a buffer to hold the results of the expression tests
- const outputBufferSize = cases.length * valueStride(resultType);
+ const outputStride = structStride([resultType], 'storage_rw');
+ const outputBufferSize = align(cases.length * outputStride, 4);
const outputBuffer = t.device.createBuffer({
size: outputBufferSize,
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
@@ -444,7 +476,7 @@ async function submitBatch(
// Read the outputs from the output buffer
const outputs = new Array<Value>(cases.length);
for (let i = 0; i < cases.length; i++) {
- outputs[i] = resultType.read(outputData, i * valueStride(resultType));
+ outputs[i] = resultType.read(outputData, i * outputStride);
}
// The list of expectation failures
@@ -498,7 +530,7 @@ function map<T, U>(v: T | readonly T[], fn: (value: T, index?: number) => U): U[
export type ShaderBuilder = (
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
inputSource: InputSource
) => string;
@@ -507,10 +539,13 @@ export type ShaderBuilder = (
*/
function wgslOutputs(resultType: Type, count: number): string {
let output_struct = undefined;
- if (scalarTypeOf(resultType).kind !== 'abstract-float') {
+ if (
+ scalarTypeOf(resultType).kind !== 'abstract-float' &&
+ scalarTypeOf(resultType).kind !== 'abstract-int'
+ ) {
output_struct = `
struct Output {
- @size(${valueStride(resultType)}) value : ${storageType(resultType)}
+ @size(${strideOf(resultType, 'storage_rw')}) value : ${storageType(resultType)}
};`;
} else {
if (resultType instanceof ScalarType) {
@@ -520,7 +555,7 @@ struct Output {
};
struct Output {
- @size(${valueStride(resultType)}) value: AF,
+ @size(${strideOf(resultType, 'storage_rw')}) value: AF,
};`;
}
if (resultType instanceof VectorType) {
@@ -531,7 +566,7 @@ struct Output {
};
struct Output {
- @size(${valueStride(resultType)}) value: array<AF, ${dim}>,
+ @size(${strideOf(resultType, 'storage_rw')}) value: array<AF, ${dim}>,
};`;
}
@@ -544,7 +579,7 @@ struct Output {
};
struct Output {
- @size(${valueStride(resultType)}) value: array<array<AF, ${rows}>, ${cols}>,
+ @size(${strideOf(resultType, 'storage_rw')}) value: array<array<AF, ${rows}>, ${cols}>,
};`;
}
@@ -562,7 +597,7 @@ struct Output {
function wgslValuesArray(
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
expressionBuilder: ExpressionBuilder
): string {
return `
@@ -612,19 +647,28 @@ function basicExpressionShaderBody(
expressionBuilder: ExpressionBuilder,
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
inputSource: InputSource
): string {
assert(
+ scalarTypeOf(resultType).kind !== 'abstract-int',
+ `abstractIntShaderBuilder should be used when result type is 'abstract-int'`
+ );
+ assert(
scalarTypeOf(resultType).kind !== 'abstract-float',
- `abstractFloatShaderBuilder should be used when result type is 'abstract-float`
+ `abstractFloatShaderBuilder should be used when result type is 'abstract-float'`
);
+ let nextUniqueIDSuffix = 0;
+ const convHelpers: TypeConversionHelpers = {
+ wgsl: '',
+ uniqueID: () => `cts_symbol_${nextUniqueIDSuffix++}`,
+ };
if (inputSource === 'const') {
//////////////////////////////////////////////////////////////////////////
// Constant eval
//////////////////////////////////////////////////////////////////////////
let body = '';
- if (parameterTypes.some(ty => scalarTypeOf(ty).kind === 'abstract-float')) {
+ if (parameterTypes.some(ty => isAbstractType(elementTypeOf(ty)))) {
// Directly assign the expression to the output, to avoid an
// intermediate store, which will concretize the value early
body = cases
@@ -632,7 +676,8 @@ function basicExpressionShaderBody(
(c, i) =>
` outputs[${i}].value = ${toStorage(
resultType,
- expressionBuilder(map(c.input, v => v.wgsl()))
+ expressionBuilder(map(c.input, v => v.wgsl())),
+ convHelpers
)};`
)
.join('\n ');
@@ -640,47 +685,60 @@ function basicExpressionShaderBody(
body = cases
.map((_, i) => {
const value = `values[${i}]`;
- return ` outputs[${i}].value = ${toStorage(resultType, value)};`;
+ return ` outputs[${i}].value = ${toStorage(resultType, value, convHelpers)};`;
})
.join('\n ');
} else {
body = `
for (var i = 0u; i < ${cases.length}; i++) {
- outputs[i].value = ${toStorage(resultType, `values[i]`)};
+ outputs[i].value = ${toStorage(resultType, `values[i]`, convHelpers)};
}`;
}
+ // If params are abstract, we will assign them directly to the storage array, so skip the values array.
+ let valuesArray = '';
+ if (!parameterTypes.some(isAbstractType)) {
+ valuesArray = wgslValuesArray(parameterTypes, resultType, cases, expressionBuilder);
+ }
+
return `
${wgslOutputs(resultType, cases.length)}
-${wgslValuesArray(parameterTypes, resultType, cases, expressionBuilder)}
+${valuesArray}
+
+${convHelpers.wgsl}
@compute @workgroup_size(1)
fn main() {
${body}
-}`;
+}
+`;
} else {
//////////////////////////////////////////////////////////////////////////
// Runtime eval
//////////////////////////////////////////////////////////////////////////
// returns the WGSL expression to load the ith parameter of the given type from the input buffer
- const paramExpr = (ty: Type, i: number) => fromStorage(ty, `inputs[i].param${i}`);
+ const paramExpr = (ty: Type, i: number) => fromStorage(ty, `inputs[i].param${i}`, convHelpers);
// resolves to the expression that calls the builtin
- const expr = toStorage(resultType, expressionBuilder(parameterTypes.map(paramExpr)));
+ const expr = toStorage(
+ resultType,
+ expressionBuilder(parameterTypes.map(paramExpr)),
+ convHelpers
+ );
return `
struct Input {
-${parameterTypes
- .map((ty, i) => ` @size(${valueStride(ty)}) param${i} : ${storageType(ty)},`)
- .join('\n')}
-};
+${wgslMembers(parameterTypes.map(storageType), inputSource, i => `param${i}`)}
+}
${wgslOutputs(resultType, cases.length)}
${wgslInputVar(inputSource, cases.length)}
+${convHelpers.wgsl}
+
@compute @workgroup_size(1)
fn main() {
for (var i = 0; i < ${cases.length}; i++) {
@@ -699,7 +757,7 @@ export function basicExpressionBuilder(expressionBuilder: ExpressionBuilder): Sh
return (
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
inputSource: InputSource
) => {
return `\
@@ -722,7 +780,7 @@ export function basicExpressionWithPredeclarationBuilder(
return (
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
inputSource: InputSource
) => {
return `\
@@ -742,7 +800,7 @@ export function compoundAssignmentBuilder(op: string): ShaderBuilder {
return (
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
inputSource: InputSource
) => {
//////////////////////////////////////////////////////////////////////////
@@ -807,8 +865,7 @@ ${wgslHeader(parameterTypes, resultType)}
${wgslOutputs(resultType, cases.length)}
struct Input {
- @size(${valueStride(lhsType)}) lhs : ${storageType(lhsType)},
- @size(${valueStride(rhsType)}) rhs : ${storageType(rhsType)},
+${wgslMembers([lhsType, rhsType].map(storageType), inputSource, i => ['lhs', 'rhs'][i])}
}
${wgslInputVar(inputSource, cases.length)}
@@ -969,10 +1026,10 @@ export function abstractFloatShaderBuilder(expressionBuilder: ExpressionBuilder)
return (
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
inputSource: InputSource
) => {
- assert(inputSource === 'const', 'AbstractFloat results are only defined for const-eval');
+ assert(inputSource === 'const', `'abstract-float' results are only defined for const-eval`);
assert(
scalarTypeOf(resultType).kind === 'abstract-float',
`Expected resultType of 'abstract-float', received '${scalarTypeOf(resultType).kind}' instead`
@@ -998,6 +1055,90 @@ ${body}
}
/**
+ * @returns a string that extracts the value of an AbstractInt into an output
+ * destination
+ * @param expr expression for an AbstractInt value, if working with vectors,
+ * this string needs to include indexing into the container.
+ * @param case_idx index in the case output array to assign the result
+ * @param accessor string representing how access to the AbstractInt that needs
+ * to be operated on.
+ * For scalars this should be left as ''.
+ * For vectors this will be an indexing operation,
+ * i.e. '[i]'
+ */
+function abstractIntSnippet(expr: string, case_idx: number, accessor: string = ''): string {
+ // AbstractInts are i64s under the hood. WebGPU does not support
+ // putting i64s in buffers, or any 64-bit simple types, so the result needs to
+ // be split up into u32 bitfields
+ //
+ // Since there is no 64-bit data type that can be used as an element for a
+ // vector or a matrix in WGSL, the testing framework needs to pass the u32s
+ // via a struct with two u32s, and deconstruct vectors into arrays.
+ //
+ // This is complicated by the fact that user defined functions cannot
+ // take/return AbstractInts, and AbstractInts cannot be stored in
+ // variables, so the code cannot just inject a simple utility function
+ // at the top of the shader, instead this snippet needs to be inlined
+ // everywhere the test needs to return an AbstractInt.
+ return ` {
+ outputs[${case_idx}].value${accessor}.high = bitcast<u32>(i32(${expr}${accessor} >> 32)) & 0xFFFFFFFF;
+ const low_sign = (${expr}${accessor} & (1 << 31));
+ outputs[${case_idx}].value${accessor}.low = bitcast<u32>((${expr}${accessor} & 0x7FFFFFFF)) | low_sign;
+ }`;
+}
+
+/** @returns a string for a specific case that has a AbstractInt result */
+function abstractIntCaseBody(expr: string, resultType: Type, i: number): string {
+ if (resultType instanceof ScalarType) {
+ return abstractIntSnippet(expr, i);
+ }
+
+ if (resultType instanceof VectorType) {
+ return [...Array(resultType.width).keys()]
+ .map(idx => abstractIntSnippet(expr, i, `[${idx}]`))
+ .join(' \n');
+ }
+
+ unreachable(`Results of type '${resultType}' not yet implemented`);
+}
+
+/**
+ * @returns a ShaderBuilder that builds a test shader hands AbstractInt results.
+ * @param expressionBuilder an expression builder that will return AbstractInts
+ */
+export function abstractIntShaderBuilder(expressionBuilder: ExpressionBuilder): ShaderBuilder {
+ return (
+ parameterTypes: Array<Type>,
+ resultType: Type,
+ cases: Case[],
+ inputSource: InputSource
+ ) => {
+ assert(inputSource === 'const', `'abstract-int' results are only defined for const-eval`);
+ assert(
+ scalarTypeOf(resultType).kind === 'abstract-int',
+ `Expected resultType of 'abstract-int', received '${scalarTypeOf(resultType).kind}' instead`
+ );
+
+ const body = cases
+ .map((c, i) => {
+ const expr = `${expressionBuilder(map(c.input, v => v.wgsl()))}`;
+ return abstractIntCaseBody(expr, resultType, i);
+ })
+ .join('\n ');
+
+ return `
+${wgslHeader(parameterTypes, resultType)}
+
+${wgslOutputs(resultType, cases.length)}
+
+@compute @workgroup_size(1)
+fn main() {
+${body}
+}`;
+ };
+}
+
+/**
* Constructs and returns a GPUComputePipeline and GPUBindGroup for running a
* batch of test cases. If a pre-created pipeline can be found in
* `pipelineCache`, then this may be returned instead of creating a new
@@ -1016,7 +1157,7 @@ async function buildPipeline(
shaderBuilder: ShaderBuilder,
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
inputSource: InputSource,
outputBuffer: GPUBuffer,
pipelineCache: PipelineCache
@@ -1060,27 +1201,23 @@ async function buildPipeline(
// Input values come from a uniform or storage buffer
// size in bytes of the input buffer
- const inputSize = cases.length * valueStrides(parameterTypes);
+ const caseStride = structStride(parameterTypes, inputSource);
+ const inputSize = align(cases.length * caseStride, 4);
// Holds all the parameter values for all cases
const inputData = new Uint8Array(inputSize);
// Pack all the input parameter values into the inputData buffer
- {
- const caseStride = valueStrides(parameterTypes);
- for (let caseIdx = 0; caseIdx < cases.length; caseIdx++) {
- const caseBase = caseIdx * caseStride;
- let offset = caseBase;
- for (let paramIdx = 0; paramIdx < parameterTypes.length; paramIdx++) {
- const params = cases[caseIdx].input;
- if (params instanceof Array) {
- params[paramIdx].copyTo(inputData, offset);
- } else {
- params.copyTo(inputData, offset);
- }
- offset += valueStride(parameterTypes[paramIdx]);
+ for (let caseIdx = 0; caseIdx < cases.length; caseIdx++) {
+ const offset = caseIdx * caseStride;
+ structLayout(parameterTypes, inputSource, m => {
+ const arg = cases[caseIdx].input;
+ if (arg instanceof Array) {
+ arg[m.index].copyTo(inputData, offset + m.offset);
+ } else {
+ arg.copyTo(inputData, offset + m.offset);
}
- }
+ });
}
// build the compute pipeline, if the shader hasn't been compiled already.
@@ -1123,12 +1260,12 @@ async function buildPipeline(
* If `cases.length` is not a multiple of `vectorWidth`, then the last scalar
* test case value is repeated to fill the vector value.
*/
-function packScalarsToVector(
+export function packScalarsToVector(
parameterTypes: Array<Type>,
resultType: Type,
- cases: CaseList,
+ cases: Case[],
vectorWidth: number
-): { cases: CaseList; parameterTypes: Array<Type>; resultType: Type } {
+): { cases: Case[]; parameterTypes: Array<Type>; resultType: Type } {
// Validate that the parameters and return type are all vectorizable
for (let i = 0; i < parameterTypes.length; i++) {
const ty = parameterTypes[i];
@@ -1145,22 +1282,22 @@ function packScalarsToVector(
}
const packedCases: Array<Case> = [];
- const packedParameterTypes = parameterTypes.map(p => TypeVec(vectorWidth, p as ScalarType));
- const packedResultType = new VectorType(vectorWidth, resultType);
+ const packedParameterTypes = parameterTypes.map(p => Type.vec(vectorWidth, p as ScalarType));
+ const packedResultType = Type.vec(vectorWidth, resultType);
const clampCaseIdx = (idx: number) => Math.min(idx, cases.length - 1);
let caseIdx = 0;
while (caseIdx < cases.length) {
// Construct the vectorized inputs from the scalar cases
- const packedInputs = new Array<Vector>(parameterTypes.length);
+ const packedInputs = new Array<VectorValue>(parameterTypes.length);
for (let paramIdx = 0; paramIdx < parameterTypes.length; paramIdx++) {
- const inputElements = new Array<Scalar>(vectorWidth);
+ const inputElements = new Array<ScalarValue>(vectorWidth);
for (let i = 0; i < vectorWidth; i++) {
const input = cases[clampCaseIdx(caseIdx + i)].input;
- inputElements[i] = (input instanceof Array ? input[paramIdx] : input) as Scalar;
+ inputElements[i] = (input instanceof Array ? input[paramIdx] : input) as ScalarValue;
}
- packedInputs[paramIdx] = new Vector(inputElements);
+ packedInputs[paramIdx] = new VectorValue(inputElements);
}
// Gather the comparators for the packed cases
@@ -1174,7 +1311,7 @@ function packScalarsToVector(
const gElements = new Array<string>(vectorWidth);
const eElements = new Array<string>(vectorWidth);
for (let i = 0; i < vectorWidth; i++) {
- const d = cmp_impls[i]((got as Vector).elements[i]);
+ const d = cmp_impls[i]((got as VectorValue).elements[i]);
matched = matched && d.matched;
gElements[i] = d.got;
eElements[i] = d.expected;
@@ -1199,238 +1336,3 @@ function packScalarsToVector(
resultType: packedResultType,
};
}
-
-/**
- * Indicates bounds that acceptance intervals need to be within to avoid inputs
- * being filtered out. This is used for const-eval tests, since going OOB will
- * cause a validation error not an execution error.
- */
-export type IntervalFilter =
- | 'finite' // Expected to be finite in the interval numeric space
- | 'unfiltered'; // No expectations
-
-/**
- * A function that performs a binary operation on x and y, and returns the expected
- * result.
- */
-export interface BinaryOp {
- (x: number, y: number): number | undefined;
-}
-
-/**
- * @returns array of Case for the input params with op applied
- * @param param0s array of inputs to try for the first param
- * @param param1s array of inputs to try for the second param
- * @param op callback called on each pair of inputs to produce each case
- * @param quantize function to quantize all values
- * @param scalarize function to convert numbers to Scalars
- */
-function generateScalarBinaryToScalarCases(
- param0s: readonly number[],
- param1s: readonly number[],
- op: BinaryOp,
- quantize: QuantizeFunc,
- scalarize: ScalarBuilder
-): Case[] {
- param0s = param0s.map(quantize);
- param1s = param1s.map(quantize);
- return cartesianProduct(param0s, param1s).reduce((cases, e) => {
- const expected = op(e[0], e[1]);
- if (expected !== undefined) {
- cases.push({ input: [scalarize(e[0]), scalarize(e[1])], expected: scalarize(expected) });
- }
- return cases;
- }, new Array<Case>());
-}
-
-/**
- * @returns an array of Cases for operations over a range of inputs
- * @param param0s array of inputs to try for the first param
- * @param param1s array of inputs to try for the second param
- * @param op callback called on each pair of inputs to produce each case
- */
-export function generateBinaryToI32Cases(
- param0s: readonly number[],
- param1s: readonly number[],
- op: BinaryOp
-) {
- return generateScalarBinaryToScalarCases(param0s, param1s, op, quantizeToI32, i32);
-}
-
-/**
- * @returns an array of Cases for operations over a range of inputs
- * @param param0s array of inputs to try for the first param
- * @param param1s array of inputs to try for the second param
- * @param op callback called on each pair of inputs to produce each case
- */
-export function generateBinaryToU32Cases(
- param0s: readonly number[],
- param1s: readonly number[],
- op: BinaryOp
-) {
- return generateScalarBinaryToScalarCases(param0s, param1s, op, quantizeToU32, u32);
-}
-
-/**
- * @returns a Case for the input params with op applied
- * @param scalar scalar param
- * @param vector vector param (2, 3, or 4 elements)
- * @param op the op to apply to scalar and vector
- * @param quantize function to quantize all values in vectors and scalars
- * @param scalarize function to convert numbers to Scalars
- */
-function makeScalarVectorBinaryToVectorCase(
- scalar: number,
- vector: readonly number[],
- op: BinaryOp,
- quantize: QuantizeFunc,
- scalarize: ScalarBuilder
-): Case | undefined {
- scalar = quantize(scalar);
- vector = vector.map(quantize);
- const result = vector.map(v => op(scalar, v));
- if (result.includes(undefined)) {
- return undefined;
- }
- return {
- input: [scalarize(scalar), new Vector(vector.map(scalarize))],
- expected: new Vector((result as readonly number[]).map(scalarize)),
- };
-}
-
-/**
- * @returns array of Case for the input params with op applied
- * @param scalars array of scalar params
- * @param vectors array of vector params (2, 3, or 4 elements)
- * @param op the op to apply to each pair of scalar and vector
- * @param quantize function to quantize all values in vectors and scalars
- * @param scalarize function to convert numbers to Scalars
- */
-function generateScalarVectorBinaryToVectorCases(
- scalars: readonly number[],
- vectors: ROArrayArray<number>,
- op: BinaryOp,
- quantize: QuantizeFunc,
- scalarize: ScalarBuilder
-): Case[] {
- const cases = new Array<Case>();
- scalars.forEach(s => {
- vectors.forEach(v => {
- const c = makeScalarVectorBinaryToVectorCase(s, v, op, quantize, scalarize);
- if (c !== undefined) {
- cases.push(c);
- }
- });
- });
- return cases;
-}
-
-/**
- * @returns a Case for the input params with op applied
- * @param vector vector param (2, 3, or 4 elements)
- * @param scalar scalar param
- * @param op the op to apply to vector and scalar
- * @param quantize function to quantize all values in vectors and scalars
- * @param scalarize function to convert numbers to Scalars
- */
-function makeVectorScalarBinaryToVectorCase(
- vector: readonly number[],
- scalar: number,
- op: BinaryOp,
- quantize: QuantizeFunc,
- scalarize: ScalarBuilder
-): Case | undefined {
- vector = vector.map(quantize);
- scalar = quantize(scalar);
- const result = vector.map(v => op(v, scalar));
- if (result.includes(undefined)) {
- return undefined;
- }
- return {
- input: [new Vector(vector.map(scalarize)), scalarize(scalar)],
- expected: new Vector((result as readonly number[]).map(scalarize)),
- };
-}
-
-/**
- * @returns array of Case for the input params with op applied
- * @param vectors array of vector params (2, 3, or 4 elements)
- * @param scalars array of scalar params
- * @param op the op to apply to each pair of vector and scalar
- * @param quantize function to quantize all values in vectors and scalars
- * @param scalarize function to convert numbers to Scalars
- */
-function generateVectorScalarBinaryToVectorCases(
- vectors: ROArrayArray<number>,
- scalars: readonly number[],
- op: BinaryOp,
- quantize: QuantizeFunc,
- scalarize: ScalarBuilder
-): Case[] {
- const cases = new Array<Case>();
- scalars.forEach(s => {
- vectors.forEach(v => {
- const c = makeVectorScalarBinaryToVectorCase(v, s, op, quantize, scalarize);
- if (c !== undefined) {
- cases.push(c);
- }
- });
- });
- return cases;
-}
-
-/**
- * @returns array of Case for the input params with op applied
- * @param scalars array of scalar params
- * @param vectors array of vector params (2, 3, or 4 elements)
- * @param op he op to apply to each pair of scalar and vector
- */
-export function generateU32VectorBinaryToVectorCases(
- scalars: readonly number[],
- vectors: ROArrayArray<number>,
- op: BinaryOp
-): Case[] {
- return generateScalarVectorBinaryToVectorCases(scalars, vectors, op, quantizeToU32, u32);
-}
-
-/**
- * @returns array of Case for the input params with op applied
- * @param vectors array of vector params (2, 3, or 4 elements)
- * @param scalars array of scalar params
- * @param op he op to apply to each pair of vector and scalar
- */
-export function generateVectorU32BinaryToVectorCases(
- vectors: ROArrayArray<number>,
- scalars: readonly number[],
- op: BinaryOp
-): Case[] {
- return generateVectorScalarBinaryToVectorCases(vectors, scalars, op, quantizeToU32, u32);
-}
-
-/**
- * @returns array of Case for the input params with op applied
- * @param scalars array of scalar params
- * @param vectors array of vector params (2, 3, or 4 elements)
- * @param op he op to apply to each pair of scalar and vector
- */
-export function generateI32VectorBinaryToVectorCases(
- scalars: readonly number[],
- vectors: ROArrayArray<number>,
- op: BinaryOp
-): Case[] {
- return generateScalarVectorBinaryToVectorCases(scalars, vectors, op, quantizeToI32, i32);
-}
-
-/**
- * @returns array of Case for the input params with op applied
- * @param vectors array of vector params (2, 3, or 4 elements)
- * @param scalars array of scalar params
- * @param op he op to apply to each pair of vector and scalar
- */
-export function generateVectorI32BinaryToVectorCases(
- vectors: ROArrayArray<number>,
- scalars: readonly number[],
- op: BinaryOp
-): Case[] {
- return generateVectorScalarBinaryToVectorCases(vectors, scalars, op, quantizeToI32, i32);
-}