diff options
Diffstat (limited to 'dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts')
-rw-r--r-- | dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts | 800 |
1 files changed, 351 insertions, 449 deletions
diff --git a/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts b/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts index f85516f29b..be8f1fd7fd 100644 --- a/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts +++ b/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/expression/expression.ts @@ -1,70 +1,25 @@ import { globalTestConfig } from '../../../../common/framework/test_config.js'; -import { ROArrayArray } from '../../../../common/util/types.js'; import { assert, objectEquals, unreachable } from '../../../../common/util/util.js'; import { GPUTest } from '../../../gpu_test.js'; -import { compare, Comparator, ComparatorImpl } from '../../../util/compare.js'; +import { Comparator, ComparatorImpl } from '../../../util/compare.js'; import { kValue } from '../../../util/constants.js'; import { + MatrixType, + ScalarValue, ScalarType, - Scalar, Type, - TypeVec, - TypeU32, - Value, - Vector, VectorType, - u32, - i32, - Matrix, - MatrixType, - ScalarBuilder, + Value, + VectorValue, + isAbstractType, scalarTypeOf, + ArrayType, + elementTypeOf, } from '../../../util/conversion.js'; -import { FPInterval } from '../../../util/floating_point.js'; -import { - cartesianProduct, - QuantizeFunc, - quantizeToI32, - quantizeToU32, -} from '../../../util/math.js'; - -export type Expectation = - | Value - | FPInterval - | readonly FPInterval[] - | ROArrayArray<FPInterval> - | Comparator; - -/** @returns if this Expectation actually a Comparator */ -export function isComparator(e: Expectation): e is Comparator { - return !( - e instanceof FPInterval || - e instanceof Scalar || - e instanceof Vector || - e instanceof Matrix || - e instanceof Array - ); -} - -/** @returns the input if it is already a Comparator, otherwise wraps it in a 'value' comparator */ -export function toComparator(input: Expectation): Comparator { - if (isComparator(input)) { - return input; - } - - return { compare: got => compare(got, input as Value), kind: 'value' }; -} - -/** Case is a single expression test case. */ -export type Case = { - // The input value(s) - input: Value | ReadonlyArray<Value>; - // The expected result, or function to check the result - expected: Expectation; -}; +import { align } from '../../../util/math.js'; -/** CaseList is a list of Cases */ -export type CaseList = Array<Case>; +import { Case } from './case.js'; +import { toComparator } from './expectation.js'; /** The input value source */ export type InputSource = @@ -79,6 +34,9 @@ export const allInputSources: InputSource[] = ['const', 'uniform', 'storage_r', /** Just constant input source */ export const onlyConstInputSource: InputSource[] = ['const']; +/** All input sources except const */ +export const allButConstInputSource: InputSource[] = ['uniform', 'storage_r', 'storage_rw']; + /** Configuration for running a expression test */ export type Config = { // Where the input values are read from @@ -92,127 +50,157 @@ export type Config = { vectorize?: number; }; -// Helper for returning the stride for a given Type -function valueStride(ty: Type): number { - // AbstractFloats are passed out of the shader via a struct of 2x u32s and - // unpacking containers as arrays - if (scalarTypeOf(ty).kind === 'abstract-float') { - if (ty instanceof ScalarType) { - return 16; - } - if (ty instanceof VectorType) { - if (ty.width === 2) { - return 16; - } - // vec3s have padding to make them the same size as vec4s - return 32; - } - if (ty instanceof MatrixType) { - switch (ty.cols) { - case 2: - switch (ty.rows) { - case 2: - return 32; - case 3: - return 64; - case 4: - return 64; - } - break; - case 3: - switch (ty.rows) { - case 2: - return 48; - case 3: - return 96; - case 4: - return 96; - } - break; - case 4: - switch (ty.rows) { - case 2: - return 64; - case 3: - return 128; - case 4: - return 128; - } - break; - } +/** + * @returns the size and alignment in bytes of the type 'ty', taking into + * consideration storage alignment constraints and abstract numerics, which are + * encoded as a struct of holding two u32s. + */ +function sizeAndAlignmentOf(ty: Type, source: InputSource): { size: number; alignment: number } { + if (ty instanceof ScalarType) { + if (ty.kind === 'abstract-float' || ty.kind === 'abstract-int') { + // AbstractFloats and AbstractInts are passed out of the shader via structs of + // 2x u32s and unpacking containers as arrays + return { size: 8, alignment: 8 }; } - unreachable(`AbstractFloats have not yet been implemented for ${ty.toString()}`); + return { size: ty.size, alignment: ty.alignment }; + } + + if (ty instanceof VectorType) { + const out = sizeAndAlignmentOf(ty.elementType, source); + const n = ty.width === 3 ? 4 : ty.width; + out.size *= n; + out.alignment *= n; + return out; } if (ty instanceof MatrixType) { - switch (ty.cols) { - case 2: - switch (ty.rows) { - case 2: - return 16; - case 3: - return 32; - case 4: - return 32; - } - break; - case 3: - switch (ty.rows) { - case 2: - return 32; - case 3: - return 64; - case 4: - return 64; - } - break; - case 4: - switch (ty.rows) { - case 2: - return 32; - case 3: - return 64; - case 4: - return 64; - } - break; + const out = sizeAndAlignmentOf(ty.elementType, source); + const n = ty.rows === 3 ? 4 : ty.rows; + out.size *= n * ty.cols; + out.alignment *= n; + return out; + } + + if (ty instanceof ArrayType) { + const out = sizeAndAlignmentOf(ty.elementType, source); + if (source === 'uniform') { + out.alignment = align(out.alignment, 16); } - unreachable( - `Attempted to get stride length for a matrix with dimensions (${ty.cols}x${ty.rows}), which isn't currently handled` - ); + out.size *= ty.count; + return out; + } + + unreachable(`unhandled type: ${ty}`); +} + +/** + * @returns the stride in bytes of the type 'ty', taking into consideration abstract numerics, + * which are encoded as a struct of 2 x u32. + */ +function strideOf(ty: Type, source: InputSource): number { + const sizeAndAlign = sizeAndAlignmentOf(ty, source); + return align(sizeAndAlign.size, sizeAndAlign.alignment); +} + +/** + * Calls 'callback' with the layout information of each structure member with the types 'members'. + * @returns the byte size, stride and alignment of the structure. + */ +export function structLayout( + members: Type[], + source: InputSource, + callback?: (m: { + index: number; + type: Type; + size: number; + alignment: number; + offset: number; + }) => void +): { size: number; stride: number; alignment: number } { + let offset = 0; + let alignment = 1; + for (let i = 0; i < members.length; i++) { + const member = members[i]; + const sizeAndAlign = sizeAndAlignmentOf(member, source); + offset = align(offset, sizeAndAlign.alignment); + if (callback) { + callback({ + index: i, + type: member, + size: sizeAndAlign.size, + alignment: sizeAndAlign.alignment, + offset, + }); + } + offset += sizeAndAlign.size; + alignment = Math.max(alignment, sizeAndAlign.alignment); } - // Handles scalars and vectors - return 16; + if (source === 'uniform') { + alignment = align(alignment, 16); + } + + const size = offset; + const stride = align(size, alignment); + return { size, stride, alignment }; +} + +/** @returns the stride in bytes between two consecutive structures with the given members */ +export function structStride(members: Type[], source: InputSource): number { + return structLayout(members, source).stride; } -// Helper for summing up all of the stride values for an array of Types -function valueStrides(tys: Type[]): number { - return tys.map(valueStride).reduce((sum, c) => sum + c); +/** @returns the WGSL to describe the structure members in 'members' */ +function wgslMembers(members: Type[], source: InputSource, memberName: (i: number) => string) { + const lines: string[] = []; + const layout = structLayout(members, source, m => { + lines.push(` @size(${m.size}) ${memberName(lines.length)} : ${m.type},`); + }); + const padding = layout.stride - layout.size; + if (padding > 0) { + // Pad with a 'f16' if the padding requires an odd multiple of 2 bytes. + // This is required as 'i32' has an alignment and size of 4 bytes. + const ty = (padding & 2) !== 0 ? 'f16' : 'i32'; + lines.push(` @size(${padding}) padding : ${ty},`); + } + return lines.join('\n'); } // Helper for returning the WGSL storage type for the given Type. function storageType(ty: Type): Type { if (ty instanceof ScalarType) { assert(ty.kind !== 'f64', `No storage type defined for 'f64' values`); + assert(ty.kind !== 'abstract-int', `Custom handling is implemented for 'abstract-int' values`); assert( ty.kind !== 'abstract-float', `Custom handling is implemented for 'abstract-float' values` ); if (ty.kind === 'bool') { - return TypeU32; + return Type.u32; } } if (ty instanceof VectorType) { - return TypeVec(ty.width, storageType(ty.elementType) as ScalarType); + return Type.vec(ty.width, storageType(ty.elementType) as ScalarType); + } + if (ty instanceof ArrayType) { + return Type.array(ty.count, storageType(ty.elementType)); } return ty; } +/** Structure used to hold [from|to]Storage conversion helpers */ +type TypeConversionHelpers = { + // The module-scope WGSL to emit with the shader. + wgsl: string; + // A function that generates a unique WGSL identifier. + uniqueID: () => string; +}; + // Helper for converting a value of the type 'ty' from the storage type. -function fromStorage(ty: Type, expr: string): string { +function fromStorage(ty: Type, expr: string, helpers: TypeConversionHelpers): string { if (ty instanceof ScalarType) { - assert(ty.kind !== 'abstract-float', `AbstractFloat values should not be in input storage`); + assert(ty.kind !== 'abstract-int', `'abstract-int' values should not be in input storage`); + assert(ty.kind !== 'abstract-float', `'abstract-float' values should not be in input storage`); assert(ty.kind !== 'f64', `'No storage type defined for 'f64' values`); if (ty.kind === 'bool') { return `${expr} != 0u`; @@ -220,23 +208,46 @@ function fromStorage(ty: Type, expr: string): string { } if (ty instanceof VectorType) { assert( + ty.elementType.kind !== 'abstract-int', + `'abstract-int' values cannot appear in input storage` + ); + assert( ty.elementType.kind !== 'abstract-float', - `AbstractFloat values cannot appear in input storage` + `'abstract-float' values cannot appear in input storage` ); assert(ty.elementType.kind !== 'f64', `'No storage type defined for 'f64' values`); if (ty.elementType.kind === 'bool') { - return `${expr} != vec${ty.width}<u32>(0u)`; + return `(${expr} != vec${ty.width}<u32>(0u))`; } } + if (ty instanceof ArrayType && elementTypeOf(ty) === Type.bool) { + // array<u32, N> -> array<bool, N> + const conv = helpers.uniqueID(); + const inTy = Type.array(ty.count, Type.u32); + helpers.wgsl += ` +fn ${conv}(in : ${inTy}) -> ${ty} { + var out : ${ty}; + for (var i = 0; i < ${ty.count}; i++) { + out[i] = in[i] != 0; + } + return out; +} +`; + return `${conv}(${expr})`; + } return expr; } // Helper for converting a value of the type 'ty' to the storage type. -function toStorage(ty: Type, expr: string): string { +function toStorage(ty: Type, expr: string, helpers: TypeConversionHelpers): string { if (ty instanceof ScalarType) { assert( + ty.kind !== 'abstract-int', + `'abstract-int' values have custom code for writing to storage` + ); + assert( ty.kind !== 'abstract-float', - `AbstractFloat values have custom code for writing to storage` + `'abstract-float' values have custom code for writing to storage` ); assert(ty.kind !== 'f64', `No storage type defined for 'f64' values`); if (ty.kind === 'bool') { @@ -245,14 +256,33 @@ function toStorage(ty: Type, expr: string): string { } if (ty instanceof VectorType) { assert( + ty.elementType.kind !== 'abstract-int', + `'abstract-int' values have custom code for writing to storage` + ); + assert( ty.elementType.kind !== 'abstract-float', - `AbstractFloat values have custom code for writing to storage` + `'abstract-float' values have custom code for writing to storage` ); assert(ty.elementType.kind !== 'f64', `'No storage type defined for 'f64' values`); if (ty.elementType.kind === 'bool') { return `select(vec${ty.width}<u32>(0u), vec${ty.width}<u32>(1u), ${expr})`; } } + if (ty instanceof ArrayType && elementTypeOf(ty) === Type.bool) { + // array<bool, N> -> array<u32, N> + const conv = helpers.uniqueID(); + const outTy = Type.array(ty.count, Type.u32); + helpers.wgsl += ` +fn ${conv}(in : ${ty}) -> ${outTy} { + var out : ${outTy}; + for (var i = 0; i < ${ty.count}; i++) { + out[i] = select(0u, 1u, in[i]); + } + return out; +} +`; + return `${conv}(${expr})`; + } return expr; } @@ -296,7 +326,7 @@ export async function run( parameterTypes: Array<Type>, resultType: Type, cfg: Config = { inputSource: 'storage_r' }, - cases: CaseList, + cases: Case[], batch_size?: number ) { // If the 'vectorize' config option was provided, pack the cases into vectors. @@ -325,12 +355,13 @@ export async function run( // 2k appears to be a sweet-spot when benchmarking. return Math.floor( Math.min(1024 * 2, t.device.limits.maxUniformBufferBindingSize) / - valueStrides(parameterTypes) + structStride(parameterTypes, cfg.inputSource) ); case 'storage_r': case 'storage_rw': return Math.floor( - t.device.limits.maxStorageBufferBindingSize / valueStrides(parameterTypes) + t.device.limits.maxStorageBufferBindingSize / + structStride(parameterTypes, cfg.inputSource) ); } })(); @@ -353,7 +384,7 @@ export async function run( } }; - const processBatch = async (batchCases: CaseList) => { + const processBatch = async (batchCases: Case[]) => { const checkBatch = await submitBatch( t, shaderBuilder, @@ -404,12 +435,13 @@ async function submitBatch( shaderBuilder: ShaderBuilder, parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], inputSource: InputSource, pipelineCache: PipelineCache ): Promise<() => void> { // Construct a buffer to hold the results of the expression tests - const outputBufferSize = cases.length * valueStride(resultType); + const outputStride = structStride([resultType], 'storage_rw'); + const outputBufferSize = align(cases.length * outputStride, 4); const outputBuffer = t.device.createBuffer({ size: outputBufferSize, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, @@ -444,7 +476,7 @@ async function submitBatch( // Read the outputs from the output buffer const outputs = new Array<Value>(cases.length); for (let i = 0; i < cases.length; i++) { - outputs[i] = resultType.read(outputData, i * valueStride(resultType)); + outputs[i] = resultType.read(outputData, i * outputStride); } // The list of expectation failures @@ -498,7 +530,7 @@ function map<T, U>(v: T | readonly T[], fn: (value: T, index?: number) => U): U[ export type ShaderBuilder = ( parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], inputSource: InputSource ) => string; @@ -507,10 +539,13 @@ export type ShaderBuilder = ( */ function wgslOutputs(resultType: Type, count: number): string { let output_struct = undefined; - if (scalarTypeOf(resultType).kind !== 'abstract-float') { + if ( + scalarTypeOf(resultType).kind !== 'abstract-float' && + scalarTypeOf(resultType).kind !== 'abstract-int' + ) { output_struct = ` struct Output { - @size(${valueStride(resultType)}) value : ${storageType(resultType)} + @size(${strideOf(resultType, 'storage_rw')}) value : ${storageType(resultType)} };`; } else { if (resultType instanceof ScalarType) { @@ -520,7 +555,7 @@ struct Output { }; struct Output { - @size(${valueStride(resultType)}) value: AF, + @size(${strideOf(resultType, 'storage_rw')}) value: AF, };`; } if (resultType instanceof VectorType) { @@ -531,7 +566,7 @@ struct Output { }; struct Output { - @size(${valueStride(resultType)}) value: array<AF, ${dim}>, + @size(${strideOf(resultType, 'storage_rw')}) value: array<AF, ${dim}>, };`; } @@ -544,7 +579,7 @@ struct Output { }; struct Output { - @size(${valueStride(resultType)}) value: array<array<AF, ${rows}>, ${cols}>, + @size(${strideOf(resultType, 'storage_rw')}) value: array<array<AF, ${rows}>, ${cols}>, };`; } @@ -562,7 +597,7 @@ struct Output { function wgslValuesArray( parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], expressionBuilder: ExpressionBuilder ): string { return ` @@ -612,19 +647,28 @@ function basicExpressionShaderBody( expressionBuilder: ExpressionBuilder, parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], inputSource: InputSource ): string { assert( + scalarTypeOf(resultType).kind !== 'abstract-int', + `abstractIntShaderBuilder should be used when result type is 'abstract-int'` + ); + assert( scalarTypeOf(resultType).kind !== 'abstract-float', - `abstractFloatShaderBuilder should be used when result type is 'abstract-float` + `abstractFloatShaderBuilder should be used when result type is 'abstract-float'` ); + let nextUniqueIDSuffix = 0; + const convHelpers: TypeConversionHelpers = { + wgsl: '', + uniqueID: () => `cts_symbol_${nextUniqueIDSuffix++}`, + }; if (inputSource === 'const') { ////////////////////////////////////////////////////////////////////////// // Constant eval ////////////////////////////////////////////////////////////////////////// let body = ''; - if (parameterTypes.some(ty => scalarTypeOf(ty).kind === 'abstract-float')) { + if (parameterTypes.some(ty => isAbstractType(elementTypeOf(ty)))) { // Directly assign the expression to the output, to avoid an // intermediate store, which will concretize the value early body = cases @@ -632,7 +676,8 @@ function basicExpressionShaderBody( (c, i) => ` outputs[${i}].value = ${toStorage( resultType, - expressionBuilder(map(c.input, v => v.wgsl())) + expressionBuilder(map(c.input, v => v.wgsl())), + convHelpers )};` ) .join('\n '); @@ -640,47 +685,60 @@ function basicExpressionShaderBody( body = cases .map((_, i) => { const value = `values[${i}]`; - return ` outputs[${i}].value = ${toStorage(resultType, value)};`; + return ` outputs[${i}].value = ${toStorage(resultType, value, convHelpers)};`; }) .join('\n '); } else { body = ` for (var i = 0u; i < ${cases.length}; i++) { - outputs[i].value = ${toStorage(resultType, `values[i]`)}; + outputs[i].value = ${toStorage(resultType, `values[i]`, convHelpers)}; }`; } + // If params are abstract, we will assign them directly to the storage array, so skip the values array. + let valuesArray = ''; + if (!parameterTypes.some(isAbstractType)) { + valuesArray = wgslValuesArray(parameterTypes, resultType, cases, expressionBuilder); + } + return ` ${wgslOutputs(resultType, cases.length)} -${wgslValuesArray(parameterTypes, resultType, cases, expressionBuilder)} +${valuesArray} + +${convHelpers.wgsl} @compute @workgroup_size(1) fn main() { ${body} -}`; +} +`; } else { ////////////////////////////////////////////////////////////////////////// // Runtime eval ////////////////////////////////////////////////////////////////////////// // returns the WGSL expression to load the ith parameter of the given type from the input buffer - const paramExpr = (ty: Type, i: number) => fromStorage(ty, `inputs[i].param${i}`); + const paramExpr = (ty: Type, i: number) => fromStorage(ty, `inputs[i].param${i}`, convHelpers); // resolves to the expression that calls the builtin - const expr = toStorage(resultType, expressionBuilder(parameterTypes.map(paramExpr))); + const expr = toStorage( + resultType, + expressionBuilder(parameterTypes.map(paramExpr)), + convHelpers + ); return ` struct Input { -${parameterTypes - .map((ty, i) => ` @size(${valueStride(ty)}) param${i} : ${storageType(ty)},`) - .join('\n')} -}; +${wgslMembers(parameterTypes.map(storageType), inputSource, i => `param${i}`)} +} ${wgslOutputs(resultType, cases.length)} ${wgslInputVar(inputSource, cases.length)} +${convHelpers.wgsl} + @compute @workgroup_size(1) fn main() { for (var i = 0; i < ${cases.length}; i++) { @@ -699,7 +757,7 @@ export function basicExpressionBuilder(expressionBuilder: ExpressionBuilder): Sh return ( parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], inputSource: InputSource ) => { return `\ @@ -722,7 +780,7 @@ export function basicExpressionWithPredeclarationBuilder( return ( parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], inputSource: InputSource ) => { return `\ @@ -742,7 +800,7 @@ export function compoundAssignmentBuilder(op: string): ShaderBuilder { return ( parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], inputSource: InputSource ) => { ////////////////////////////////////////////////////////////////////////// @@ -807,8 +865,7 @@ ${wgslHeader(parameterTypes, resultType)} ${wgslOutputs(resultType, cases.length)} struct Input { - @size(${valueStride(lhsType)}) lhs : ${storageType(lhsType)}, - @size(${valueStride(rhsType)}) rhs : ${storageType(rhsType)}, +${wgslMembers([lhsType, rhsType].map(storageType), inputSource, i => ['lhs', 'rhs'][i])} } ${wgslInputVar(inputSource, cases.length)} @@ -969,10 +1026,10 @@ export function abstractFloatShaderBuilder(expressionBuilder: ExpressionBuilder) return ( parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], inputSource: InputSource ) => { - assert(inputSource === 'const', 'AbstractFloat results are only defined for const-eval'); + assert(inputSource === 'const', `'abstract-float' results are only defined for const-eval`); assert( scalarTypeOf(resultType).kind === 'abstract-float', `Expected resultType of 'abstract-float', received '${scalarTypeOf(resultType).kind}' instead` @@ -998,6 +1055,90 @@ ${body} } /** + * @returns a string that extracts the value of an AbstractInt into an output + * destination + * @param expr expression for an AbstractInt value, if working with vectors, + * this string needs to include indexing into the container. + * @param case_idx index in the case output array to assign the result + * @param accessor string representing how access to the AbstractInt that needs + * to be operated on. + * For scalars this should be left as ''. + * For vectors this will be an indexing operation, + * i.e. '[i]' + */ +function abstractIntSnippet(expr: string, case_idx: number, accessor: string = ''): string { + // AbstractInts are i64s under the hood. WebGPU does not support + // putting i64s in buffers, or any 64-bit simple types, so the result needs to + // be split up into u32 bitfields + // + // Since there is no 64-bit data type that can be used as an element for a + // vector or a matrix in WGSL, the testing framework needs to pass the u32s + // via a struct with two u32s, and deconstruct vectors into arrays. + // + // This is complicated by the fact that user defined functions cannot + // take/return AbstractInts, and AbstractInts cannot be stored in + // variables, so the code cannot just inject a simple utility function + // at the top of the shader, instead this snippet needs to be inlined + // everywhere the test needs to return an AbstractInt. + return ` { + outputs[${case_idx}].value${accessor}.high = bitcast<u32>(i32(${expr}${accessor} >> 32)) & 0xFFFFFFFF; + const low_sign = (${expr}${accessor} & (1 << 31)); + outputs[${case_idx}].value${accessor}.low = bitcast<u32>((${expr}${accessor} & 0x7FFFFFFF)) | low_sign; + }`; +} + +/** @returns a string for a specific case that has a AbstractInt result */ +function abstractIntCaseBody(expr: string, resultType: Type, i: number): string { + if (resultType instanceof ScalarType) { + return abstractIntSnippet(expr, i); + } + + if (resultType instanceof VectorType) { + return [...Array(resultType.width).keys()] + .map(idx => abstractIntSnippet(expr, i, `[${idx}]`)) + .join(' \n'); + } + + unreachable(`Results of type '${resultType}' not yet implemented`); +} + +/** + * @returns a ShaderBuilder that builds a test shader hands AbstractInt results. + * @param expressionBuilder an expression builder that will return AbstractInts + */ +export function abstractIntShaderBuilder(expressionBuilder: ExpressionBuilder): ShaderBuilder { + return ( + parameterTypes: Array<Type>, + resultType: Type, + cases: Case[], + inputSource: InputSource + ) => { + assert(inputSource === 'const', `'abstract-int' results are only defined for const-eval`); + assert( + scalarTypeOf(resultType).kind === 'abstract-int', + `Expected resultType of 'abstract-int', received '${scalarTypeOf(resultType).kind}' instead` + ); + + const body = cases + .map((c, i) => { + const expr = `${expressionBuilder(map(c.input, v => v.wgsl()))}`; + return abstractIntCaseBody(expr, resultType, i); + }) + .join('\n '); + + return ` +${wgslHeader(parameterTypes, resultType)} + +${wgslOutputs(resultType, cases.length)} + +@compute @workgroup_size(1) +fn main() { +${body} +}`; + }; +} + +/** * Constructs and returns a GPUComputePipeline and GPUBindGroup for running a * batch of test cases. If a pre-created pipeline can be found in * `pipelineCache`, then this may be returned instead of creating a new @@ -1016,7 +1157,7 @@ async function buildPipeline( shaderBuilder: ShaderBuilder, parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], inputSource: InputSource, outputBuffer: GPUBuffer, pipelineCache: PipelineCache @@ -1060,27 +1201,23 @@ async function buildPipeline( // Input values come from a uniform or storage buffer // size in bytes of the input buffer - const inputSize = cases.length * valueStrides(parameterTypes); + const caseStride = structStride(parameterTypes, inputSource); + const inputSize = align(cases.length * caseStride, 4); // Holds all the parameter values for all cases const inputData = new Uint8Array(inputSize); // Pack all the input parameter values into the inputData buffer - { - const caseStride = valueStrides(parameterTypes); - for (let caseIdx = 0; caseIdx < cases.length; caseIdx++) { - const caseBase = caseIdx * caseStride; - let offset = caseBase; - for (let paramIdx = 0; paramIdx < parameterTypes.length; paramIdx++) { - const params = cases[caseIdx].input; - if (params instanceof Array) { - params[paramIdx].copyTo(inputData, offset); - } else { - params.copyTo(inputData, offset); - } - offset += valueStride(parameterTypes[paramIdx]); + for (let caseIdx = 0; caseIdx < cases.length; caseIdx++) { + const offset = caseIdx * caseStride; + structLayout(parameterTypes, inputSource, m => { + const arg = cases[caseIdx].input; + if (arg instanceof Array) { + arg[m.index].copyTo(inputData, offset + m.offset); + } else { + arg.copyTo(inputData, offset + m.offset); } - } + }); } // build the compute pipeline, if the shader hasn't been compiled already. @@ -1123,12 +1260,12 @@ async function buildPipeline( * If `cases.length` is not a multiple of `vectorWidth`, then the last scalar * test case value is repeated to fill the vector value. */ -function packScalarsToVector( +export function packScalarsToVector( parameterTypes: Array<Type>, resultType: Type, - cases: CaseList, + cases: Case[], vectorWidth: number -): { cases: CaseList; parameterTypes: Array<Type>; resultType: Type } { +): { cases: Case[]; parameterTypes: Array<Type>; resultType: Type } { // Validate that the parameters and return type are all vectorizable for (let i = 0; i < parameterTypes.length; i++) { const ty = parameterTypes[i]; @@ -1145,22 +1282,22 @@ function packScalarsToVector( } const packedCases: Array<Case> = []; - const packedParameterTypes = parameterTypes.map(p => TypeVec(vectorWidth, p as ScalarType)); - const packedResultType = new VectorType(vectorWidth, resultType); + const packedParameterTypes = parameterTypes.map(p => Type.vec(vectorWidth, p as ScalarType)); + const packedResultType = Type.vec(vectorWidth, resultType); const clampCaseIdx = (idx: number) => Math.min(idx, cases.length - 1); let caseIdx = 0; while (caseIdx < cases.length) { // Construct the vectorized inputs from the scalar cases - const packedInputs = new Array<Vector>(parameterTypes.length); + const packedInputs = new Array<VectorValue>(parameterTypes.length); for (let paramIdx = 0; paramIdx < parameterTypes.length; paramIdx++) { - const inputElements = new Array<Scalar>(vectorWidth); + const inputElements = new Array<ScalarValue>(vectorWidth); for (let i = 0; i < vectorWidth; i++) { const input = cases[clampCaseIdx(caseIdx + i)].input; - inputElements[i] = (input instanceof Array ? input[paramIdx] : input) as Scalar; + inputElements[i] = (input instanceof Array ? input[paramIdx] : input) as ScalarValue; } - packedInputs[paramIdx] = new Vector(inputElements); + packedInputs[paramIdx] = new VectorValue(inputElements); } // Gather the comparators for the packed cases @@ -1174,7 +1311,7 @@ function packScalarsToVector( const gElements = new Array<string>(vectorWidth); const eElements = new Array<string>(vectorWidth); for (let i = 0; i < vectorWidth; i++) { - const d = cmp_impls[i]((got as Vector).elements[i]); + const d = cmp_impls[i]((got as VectorValue).elements[i]); matched = matched && d.matched; gElements[i] = d.got; eElements[i] = d.expected; @@ -1199,238 +1336,3 @@ function packScalarsToVector( resultType: packedResultType, }; } - -/** - * Indicates bounds that acceptance intervals need to be within to avoid inputs - * being filtered out. This is used for const-eval tests, since going OOB will - * cause a validation error not an execution error. - */ -export type IntervalFilter = - | 'finite' // Expected to be finite in the interval numeric space - | 'unfiltered'; // No expectations - -/** - * A function that performs a binary operation on x and y, and returns the expected - * result. - */ -export interface BinaryOp { - (x: number, y: number): number | undefined; -} - -/** - * @returns array of Case for the input params with op applied - * @param param0s array of inputs to try for the first param - * @param param1s array of inputs to try for the second param - * @param op callback called on each pair of inputs to produce each case - * @param quantize function to quantize all values - * @param scalarize function to convert numbers to Scalars - */ -function generateScalarBinaryToScalarCases( - param0s: readonly number[], - param1s: readonly number[], - op: BinaryOp, - quantize: QuantizeFunc, - scalarize: ScalarBuilder -): Case[] { - param0s = param0s.map(quantize); - param1s = param1s.map(quantize); - return cartesianProduct(param0s, param1s).reduce((cases, e) => { - const expected = op(e[0], e[1]); - if (expected !== undefined) { - cases.push({ input: [scalarize(e[0]), scalarize(e[1])], expected: scalarize(expected) }); - } - return cases; - }, new Array<Case>()); -} - -/** - * @returns an array of Cases for operations over a range of inputs - * @param param0s array of inputs to try for the first param - * @param param1s array of inputs to try for the second param - * @param op callback called on each pair of inputs to produce each case - */ -export function generateBinaryToI32Cases( - param0s: readonly number[], - param1s: readonly number[], - op: BinaryOp -) { - return generateScalarBinaryToScalarCases(param0s, param1s, op, quantizeToI32, i32); -} - -/** - * @returns an array of Cases for operations over a range of inputs - * @param param0s array of inputs to try for the first param - * @param param1s array of inputs to try for the second param - * @param op callback called on each pair of inputs to produce each case - */ -export function generateBinaryToU32Cases( - param0s: readonly number[], - param1s: readonly number[], - op: BinaryOp -) { - return generateScalarBinaryToScalarCases(param0s, param1s, op, quantizeToU32, u32); -} - -/** - * @returns a Case for the input params with op applied - * @param scalar scalar param - * @param vector vector param (2, 3, or 4 elements) - * @param op the op to apply to scalar and vector - * @param quantize function to quantize all values in vectors and scalars - * @param scalarize function to convert numbers to Scalars - */ -function makeScalarVectorBinaryToVectorCase( - scalar: number, - vector: readonly number[], - op: BinaryOp, - quantize: QuantizeFunc, - scalarize: ScalarBuilder -): Case | undefined { - scalar = quantize(scalar); - vector = vector.map(quantize); - const result = vector.map(v => op(scalar, v)); - if (result.includes(undefined)) { - return undefined; - } - return { - input: [scalarize(scalar), new Vector(vector.map(scalarize))], - expected: new Vector((result as readonly number[]).map(scalarize)), - }; -} - -/** - * @returns array of Case for the input params with op applied - * @param scalars array of scalar params - * @param vectors array of vector params (2, 3, or 4 elements) - * @param op the op to apply to each pair of scalar and vector - * @param quantize function to quantize all values in vectors and scalars - * @param scalarize function to convert numbers to Scalars - */ -function generateScalarVectorBinaryToVectorCases( - scalars: readonly number[], - vectors: ROArrayArray<number>, - op: BinaryOp, - quantize: QuantizeFunc, - scalarize: ScalarBuilder -): Case[] { - const cases = new Array<Case>(); - scalars.forEach(s => { - vectors.forEach(v => { - const c = makeScalarVectorBinaryToVectorCase(s, v, op, quantize, scalarize); - if (c !== undefined) { - cases.push(c); - } - }); - }); - return cases; -} - -/** - * @returns a Case for the input params with op applied - * @param vector vector param (2, 3, or 4 elements) - * @param scalar scalar param - * @param op the op to apply to vector and scalar - * @param quantize function to quantize all values in vectors and scalars - * @param scalarize function to convert numbers to Scalars - */ -function makeVectorScalarBinaryToVectorCase( - vector: readonly number[], - scalar: number, - op: BinaryOp, - quantize: QuantizeFunc, - scalarize: ScalarBuilder -): Case | undefined { - vector = vector.map(quantize); - scalar = quantize(scalar); - const result = vector.map(v => op(v, scalar)); - if (result.includes(undefined)) { - return undefined; - } - return { - input: [new Vector(vector.map(scalarize)), scalarize(scalar)], - expected: new Vector((result as readonly number[]).map(scalarize)), - }; -} - -/** - * @returns array of Case for the input params with op applied - * @param vectors array of vector params (2, 3, or 4 elements) - * @param scalars array of scalar params - * @param op the op to apply to each pair of vector and scalar - * @param quantize function to quantize all values in vectors and scalars - * @param scalarize function to convert numbers to Scalars - */ -function generateVectorScalarBinaryToVectorCases( - vectors: ROArrayArray<number>, - scalars: readonly number[], - op: BinaryOp, - quantize: QuantizeFunc, - scalarize: ScalarBuilder -): Case[] { - const cases = new Array<Case>(); - scalars.forEach(s => { - vectors.forEach(v => { - const c = makeVectorScalarBinaryToVectorCase(v, s, op, quantize, scalarize); - if (c !== undefined) { - cases.push(c); - } - }); - }); - return cases; -} - -/** - * @returns array of Case for the input params with op applied - * @param scalars array of scalar params - * @param vectors array of vector params (2, 3, or 4 elements) - * @param op he op to apply to each pair of scalar and vector - */ -export function generateU32VectorBinaryToVectorCases( - scalars: readonly number[], - vectors: ROArrayArray<number>, - op: BinaryOp -): Case[] { - return generateScalarVectorBinaryToVectorCases(scalars, vectors, op, quantizeToU32, u32); -} - -/** - * @returns array of Case for the input params with op applied - * @param vectors array of vector params (2, 3, or 4 elements) - * @param scalars array of scalar params - * @param op he op to apply to each pair of vector and scalar - */ -export function generateVectorU32BinaryToVectorCases( - vectors: ROArrayArray<number>, - scalars: readonly number[], - op: BinaryOp -): Case[] { - return generateVectorScalarBinaryToVectorCases(vectors, scalars, op, quantizeToU32, u32); -} - -/** - * @returns array of Case for the input params with op applied - * @param scalars array of scalar params - * @param vectors array of vector params (2, 3, or 4 elements) - * @param op he op to apply to each pair of scalar and vector - */ -export function generateI32VectorBinaryToVectorCases( - scalars: readonly number[], - vectors: ROArrayArray<number>, - op: BinaryOp -): Case[] { - return generateScalarVectorBinaryToVectorCases(scalars, vectors, op, quantizeToI32, i32); -} - -/** - * @returns array of Case for the input params with op applied - * @param vectors array of vector params (2, 3, or 4 elements) - * @param scalars array of scalar params - * @param op he op to apply to each pair of vector and scalar - */ -export function generateVectorI32BinaryToVectorCases( - vectors: ROArrayArray<number>, - scalars: readonly number[], - op: BinaryOp -): Case[] { - return generateVectorScalarBinaryToVectorCases(vectors, scalars, op, quantizeToI32, i32); -} |