summaryrefslogtreecommitdiffstats
path: root/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts
diff options
context:
space:
mode:
Diffstat (limited to 'dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts')
-rw-r--r--dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts175
1 files changed, 68 insertions, 107 deletions
diff --git a/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts b/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts
index fcf3159c64..a40b426332 100644
--- a/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts
+++ b/dom/webgpu/tests/cts/checkout/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts
@@ -1,7 +1,6 @@
export const description = `Test compute shader builtin variables`;
import { makeTestGroup } from '../../../../common/framework/test_group.js';
-import { iterRange } from '../../../../common/util/util.js';
import { GPUTest } from '../../../gpu_test.js';
export const g = makeTestGroup(GPUTest);
@@ -98,17 +97,14 @@ g.test('inputs')
// WGSL shader that stores every builtin value to a buffer, for every invocation in the grid.
const wgsl = `
- struct S {
- data : array<u32>
+ struct Outputs {
+ local_id: vec3u,
+ local_index: u32,
+ global_id: vec3u,
+ group_id: vec3u,
+ num_groups: vec3u,
};
- struct V {
- data : array<vec3<u32>>
- };
- @group(0) @binding(0) var<storage, read_write> local_id_out : V;
- @group(0) @binding(1) var<storage, read_write> local_index_out : S;
- @group(0) @binding(2) var<storage, read_write> global_id_out : V;
- @group(0) @binding(3) var<storage, read_write> group_id_out : V;
- @group(0) @binding(4) var<storage, read_write> num_groups_out : V;
+ @group(0) @binding(0) var<storage, read_write> outputs : array<Outputs>;
${structures}
@@ -122,11 +118,13 @@ g.test('inputs')
) {
let group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x;
let global_index = group_index * ${invocationsPerGroup}u + ${local_index};
- local_id_out.data[global_index] = ${local_id};
- local_index_out.data[global_index] = ${local_index};
- global_id_out.data[global_index] = ${global_id};
- group_id_out.data[global_index] = ${group_id};
- num_groups_out.data[global_index] = ${num_groups};
+ var o: Outputs;
+ o.local_id = ${local_id};
+ o.local_index = ${local_index};
+ o.global_id = ${global_id};
+ o.group_id = ${group_id};
+ o.num_groups = ${num_groups};
+ outputs[global_index] = o;
}
`;
@@ -140,35 +138,24 @@ g.test('inputs')
},
});
- // Helper to create a `size`-byte buffer with binding number `binding`.
- function createBuffer(size: number, binding: number) {
- const buffer = t.device.createBuffer({
- size,
- usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
- });
- t.trackForCleanup(buffer);
-
- bindGroupEntries.push({
- binding,
- resource: {
- buffer,
- },
- });
-
- return buffer;
- }
+ // Offsets are in u32 size units
+ const kLocalIdOffset = 0;
+ const kLocalIndexOffset = 3;
+ const kGlobalIdOffset = 4;
+ const kGroupIdOffset = 8;
+ const kNumGroupsOffset = 12;
+ const kOutputElementSize = 16;
// Create the output buffers.
- const bindGroupEntries: GPUBindGroupEntry[] = [];
- const localIdBuffer = createBuffer(totalInvocations * 16, 0);
- const localIndexBuffer = createBuffer(totalInvocations * 4, 1);
- const globalIdBuffer = createBuffer(totalInvocations * 16, 2);
- const groupIdBuffer = createBuffer(totalInvocations * 16, 3);
- const numGroupsBuffer = createBuffer(totalInvocations * 16, 4);
+ const outputBuffer = t.device.createBuffer({
+ size: totalInvocations * kOutputElementSize * 4,
+ usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
+ });
+ t.trackForCleanup(outputBuffer);
const bindGroup = t.device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
- entries: bindGroupEntries,
+ entries: [{ binding: 0, resource: { buffer: outputBuffer } }],
});
// Run the shader.
@@ -204,11 +191,7 @@ g.test('inputs')
// Helper to check that the vec3<u32> value at each index of the provided `output` buffer
// matches the expected value for that invocation, as generated by the `getBuiltinValue`
// function. The `name` parameter is the builtin name, used for error messages.
- const checkEachIndex = (
- output: Uint32Array,
- name: string,
- getBuiltinValue: (groupId: vec3, localId: vec3) => vec3
- ) => {
+ const checkEachIndex = (output: Uint32Array) => {
// Loop over workgroups.
for (let gz = 0; gz < t.params.numGroups.z; gz++) {
for (let gy = 0; gy < t.params.numGroups.y; gy++) {
@@ -220,30 +203,44 @@ g.test('inputs')
const groupIndex = (gz * t.params.numGroups.y + gy) * t.params.numGroups.x + gx;
const localIndex = (lz * t.params.groupSize.y + ly) * t.params.groupSize.x + lx;
const globalIndex = groupIndex * invocationsPerGroup + localIndex;
- const expected = getBuiltinValue(
- { x: gx, y: gy, z: gz },
- { x: lx, y: ly, z: lz }
- );
- if (output[globalIndex * 4 + 0] !== expected.x) {
- return new Error(
- `${name}.x failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
- ` expected: ${expected.x}\n` +
- ` got: ${output[globalIndex * 4 + 0]}`
- );
- }
- if (output[globalIndex * 4 + 1] !== expected.y) {
- return new Error(
- `${name}.y failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
- ` expected: ${expected.y}\n` +
- ` got: ${output[globalIndex * 4 + 1]}`
+ const globalOffset = globalIndex * kOutputElementSize;
+
+ const expectEqual = (name: string, expected: number, actual: number) => {
+ if (actual !== expected) {
+ return new Error(
+ `${name} failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
+ ` expected: ${expected}\n` +
+ ` got: ${actual}`
+ );
+ }
+ return undefined;
+ };
+
+ const checkVec3Value = (name: string, fieldOffset: number, expected: vec3) => {
+ const offset = globalOffset + fieldOffset;
+ return (
+ expectEqual(`${name}.x`, expected.x, output[offset + 0]) ||
+ expectEqual(`${name}.y`, expected.y, output[offset + 1]) ||
+ expectEqual(`${name}.z`, expected.z, output[offset + 2])
);
- }
- if (output[globalIndex * 4 + 2] !== expected.z) {
- return new Error(
- `${name}.z failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
- ` expected: ${expected.z}\n` +
- ` got: ${output[globalIndex * 4 + 2]}`
+ };
+
+ const error =
+ checkVec3Value('local_id', kLocalIdOffset, { x: lx, y: ly, z: lz }) ||
+ checkVec3Value('global_id', kGlobalIdOffset, {
+ x: gx * t.params.groupSize.x + lx,
+ y: gy * t.params.groupSize.y + ly,
+ z: gz * t.params.groupSize.z + lz,
+ }) ||
+ checkVec3Value('group_id', kGroupIdOffset, { x: gx, y: gy, z: gz }) ||
+ checkVec3Value('num_groups', kNumGroupsOffset, t.params.numGroups) ||
+ expectEqual(
+ 'local_index',
+ localIndex,
+ output[globalOffset + kLocalIndexOffset]
);
+ if (error) {
+ return error;
}
}
}
@@ -254,44 +251,8 @@ g.test('inputs')
return undefined;
};
- // Check @builtin(local_invocation_index) values.
- t.expectGPUBufferValuesEqual(
- localIndexBuffer,
- new Uint32Array([...iterRange(totalInvocations, x => x % invocationsPerGroup)])
- );
-
- // Check @builtin(local_invocation_id) values.
- t.expectGPUBufferValuesPassCheck(
- localIdBuffer,
- outputData => checkEachIndex(outputData, 'local_invocation_id', (_, localId) => localId),
- { type: Uint32Array, typedLength: totalInvocations * 4 }
- );
-
- // Check @builtin(global_invocation_id) values.
- const getGlobalId = (groupId: vec3, localId: vec3) => {
- return {
- x: groupId.x * t.params.groupSize.x + localId.x,
- y: groupId.y * t.params.groupSize.y + localId.y,
- z: groupId.z * t.params.groupSize.z + localId.z,
- };
- };
- t.expectGPUBufferValuesPassCheck(
- globalIdBuffer,
- outputData => checkEachIndex(outputData, 'global_invocation_id', getGlobalId),
- { type: Uint32Array, typedLength: totalInvocations * 4 }
- );
-
- // Check @builtin(workgroup_id) values.
- t.expectGPUBufferValuesPassCheck(
- groupIdBuffer,
- outputData => checkEachIndex(outputData, 'workgroup_id', (groupId, _) => groupId),
- { type: Uint32Array, typedLength: totalInvocations * 4 }
- );
-
- // Check @builtin(num_workgroups) values.
- t.expectGPUBufferValuesPassCheck(
- numGroupsBuffer,
- outputData => checkEachIndex(outputData, 'num_workgroups', () => t.params.numGroups),
- { type: Uint32Array, typedLength: totalInvocations * 4 }
- );
+ t.expectGPUBufferValuesPassCheck(outputBuffer, outputData => checkEachIndex(outputData), {
+ type: Uint32Array,
+ typedLength: outputBuffer.size / 4,
+ });
});