summaryrefslogtreecommitdiffstats
path: root/dom/webgpu/tests/cts/checkout/src/webgpu/api/operation/command_buffer/programmable/programmable_state_test.ts
diff options
context:
space:
mode:
Diffstat (limited to 'dom/webgpu/tests/cts/checkout/src/webgpu/api/operation/command_buffer/programmable/programmable_state_test.ts')
-rw-r--r--dom/webgpu/tests/cts/checkout/src/webgpu/api/operation/command_buffer/programmable/programmable_state_test.ts157
1 files changed, 157 insertions, 0 deletions
diff --git a/dom/webgpu/tests/cts/checkout/src/webgpu/api/operation/command_buffer/programmable/programmable_state_test.ts b/dom/webgpu/tests/cts/checkout/src/webgpu/api/operation/command_buffer/programmable/programmable_state_test.ts
new file mode 100644
index 0000000000..19cf91419c
--- /dev/null
+++ b/dom/webgpu/tests/cts/checkout/src/webgpu/api/operation/command_buffer/programmable/programmable_state_test.ts
@@ -0,0 +1,157 @@
+import { unreachable } from '../../../../../common/util/util.js';
+import { GPUTest } from '../../../../gpu_test.js';
+import { EncoderType } from '../../../../util/command_buffer_maker.js';
+
+interface BindGroupIndices {
+ a: number;
+ b: number;
+ out: number;
+}
+
+export class ProgrammableStateTest extends GPUTest {
+ private commonBindGroupLayouts: Map<string, GPUBindGroupLayout> = new Map();
+
+ getBindGroupLayout(type: GPUBufferBindingType): GPUBindGroupLayout {
+ if (!this.commonBindGroupLayouts.has(type)) {
+ this.commonBindGroupLayouts.set(
+ type,
+ this.device.createBindGroupLayout({
+ entries: [
+ {
+ binding: 0,
+ visibility: GPUShaderStage.COMPUTE | GPUShaderStage.FRAGMENT,
+ buffer: { type },
+ },
+ ],
+ })
+ );
+ }
+ return this.commonBindGroupLayouts.get(type)!;
+ }
+
+ getBindGroupLayouts(indices: BindGroupIndices): GPUBindGroupLayout[] {
+ const bindGroupLayouts: GPUBindGroupLayout[] = [];
+ bindGroupLayouts[indices.a] = this.getBindGroupLayout('read-only-storage');
+ bindGroupLayouts[indices.b] = this.getBindGroupLayout('read-only-storage');
+ bindGroupLayouts[indices.out] = this.getBindGroupLayout('storage');
+ return bindGroupLayouts;
+ }
+
+ createBindGroup(buffer: GPUBuffer, type: GPUBufferBindingType): GPUBindGroup {
+ return this.device.createBindGroup({
+ layout: this.getBindGroupLayout(type),
+ entries: [{ binding: 0, resource: { buffer } }],
+ });
+ }
+
+ setBindGroup(
+ encoder: GPUBindingCommandsMixin,
+ index: number,
+ factory: (index: number) => GPUBindGroup
+ ) {
+ encoder.setBindGroup(index, factory(index));
+ }
+
+ // Create a compute pipeline that performs an operation on data from two bind groups,
+ // then writes the result to a third bind group.
+ createBindingStatePipeline<T extends EncoderType>(
+ encoderType: T,
+ groups: BindGroupIndices,
+ algorithm: string = 'a.value - b.value'
+ ): GPUComputePipeline | GPURenderPipeline {
+ switch (encoderType) {
+ case 'compute pass': {
+ const wgsl = `struct Data {
+ value : i32
+ };
+
+ @group(${groups.a}) @binding(0) var<storage> a : Data;
+ @group(${groups.b}) @binding(0) var<storage> b : Data;
+ @group(${groups.out}) @binding(0) var<storage, read_write> out : Data;
+
+ @compute @workgroup_size(1) fn main() {
+ out.value = ${algorithm};
+ return;
+ }
+ `;
+
+ return this.device.createComputePipeline({
+ layout: this.device.createPipelineLayout({
+ bindGroupLayouts: this.getBindGroupLayouts(groups),
+ }),
+ compute: {
+ module: this.device.createShaderModule({
+ code: wgsl,
+ }),
+ entryPoint: 'main',
+ },
+ });
+ }
+ case 'render pass':
+ case 'render bundle': {
+ const wgslShaders = {
+ vertex: `
+ @vertex fn vert_main() -> @builtin(position) vec4<f32> {
+ return vec4<f32>(0.5, 0.5, 0.0, 1.0);
+ }
+ `,
+
+ fragment: `
+ struct Data {
+ value : i32
+ };
+
+ @group(${groups.a}) @binding(0) var<storage> a : Data;
+ @group(${groups.b}) @binding(0) var<storage> b : Data;
+ @group(${groups.out}) @binding(0) var<storage, read_write> out : Data;
+
+ @fragment fn frag_main() -> @location(0) vec4<f32> {
+ out.value = ${algorithm};
+ return vec4<f32>(1.0, 0.0, 0.0, 1.0);
+ }
+ `,
+ };
+
+ return this.device.createRenderPipeline({
+ layout: this.device.createPipelineLayout({
+ bindGroupLayouts: this.getBindGroupLayouts(groups),
+ }),
+ vertex: {
+ module: this.device.createShaderModule({
+ code: wgslShaders.vertex,
+ }),
+ entryPoint: 'vert_main',
+ },
+ fragment: {
+ module: this.device.createShaderModule({
+ code: wgslShaders.fragment,
+ }),
+ entryPoint: 'frag_main',
+ targets: [{ format: 'rgba8unorm' }],
+ },
+ primitive: { topology: 'point-list' },
+ });
+ }
+ default:
+ unreachable();
+ }
+ }
+
+ setPipeline(pass: GPUBindingCommandsMixin, pipeline: GPUComputePipeline | GPURenderPipeline) {
+ if (pass instanceof GPUComputePassEncoder) {
+ pass.setPipeline(pipeline as GPUComputePipeline);
+ } else if (pass instanceof GPURenderPassEncoder || pass instanceof GPURenderBundleEncoder) {
+ pass.setPipeline(pipeline as GPURenderPipeline);
+ }
+ }
+
+ dispatchOrDraw(pass: GPUBindingCommandsMixin) {
+ if (pass instanceof GPUComputePassEncoder) {
+ pass.dispatchWorkgroups(1);
+ } else if (pass instanceof GPURenderPassEncoder) {
+ pass.draw(1);
+ } else if (pass instanceof GPURenderBundleEncoder) {
+ pass.draw(1);
+ }
+ }
+}