From 8dd16259287f58f9273002717ec4d27e97127719 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Wed, 12 Jun 2024 07:43:14 +0200
Subject: Merging upstream version 127.0.

Signed-off-by: Daniel Baumann
---
 third_party/rust/wgpu-core/src/command/compute.rs | 318 ++++++++--------------
 1 file changed, 117 insertions(+), 201 deletions(-)

(limited to 'third_party/rust/wgpu-core/src/command/compute.rs')

diff --git a/third_party/rust/wgpu-core/src/command/compute.rs b/third_party/rust/wgpu-core/src/command/compute.rs
index b38324984c..4ee48f0086 100644
--- a/third_party/rust/wgpu-core/src/command/compute.rs
+++ b/third_party/rust/wgpu-core/src/command/compute.rs
@@ -1,3 +1,4 @@
+use crate::command::compute_command::{ArcComputeCommand, ComputeCommand};
 use crate::device::DeviceError;
 use crate::resource::Resource;
 use crate::snatch::SnatchGuard;
@@ -20,7 +21,6 @@ use crate::{
     hal_label, id,
     id::DeviceId,
     init_tracker::MemoryInitKind,
-    pipeline,
     resource::{self},
     storage::Storage,
     track::{Tracker, UsageConflict, UsageScope},
@@ -36,61 +36,9 @@ use serde::Serialize;
 
 use thiserror::Error;
 
+use std::sync::Arc;
 use std::{fmt, mem, str};
 
-#[doc(hidden)]
-#[derive(Clone, Copy, Debug)]
-#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
-pub enum ComputeCommand {
-    SetBindGroup {
-        index: u32,
-        num_dynamic_offsets: usize,
-        bind_group_id: id::BindGroupId,
-    },
-    SetPipeline(id::ComputePipelineId),
-
-    /// Set a range of push constants to values stored in [`BasePass::push_constant_data`].
-    SetPushConstant {
-        /// The byte offset within the push constant storage to write to. This
-        /// must be a multiple of four.
-        offset: u32,
-
-        /// The number of bytes to write. This must be a multiple of four.
-        size_bytes: u32,
-
-        /// Index in [`BasePass::push_constant_data`] of the start of the data
-        /// to be written.
-        ///
-        /// Note: this is not a byte offset like `offset`. Rather, it is the
-        /// index of the first `u32` element in `push_constant_data` to read.
-        values_offset: u32,
-    },
-
-    Dispatch([u32; 3]),
-    DispatchIndirect {
-        buffer_id: id::BufferId,
-        offset: wgt::BufferAddress,
-    },
-    PushDebugGroup {
-        color: u32,
-        len: usize,
-    },
-    PopDebugGroup,
-    InsertDebugMarker {
-        color: u32,
-        len: usize,
-    },
-    WriteTimestamp {
-        query_set_id: id::QuerySetId,
-        query_index: u32,
-    },
-    BeginPipelineStatisticsQuery {
-        query_set_id: id::QuerySetId,
-        query_index: u32,
-    },
-    EndPipelineStatisticsQuery,
-}
-
 #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
 pub struct ComputePass {
     base: BasePass<ComputeCommand>,
@@ -184,7 +132,7 @@ pub enum ComputePassErrorInner {
     #[error(transparent)]
     Encoder(#[from] CommandEncoderError),
     #[error("Bind group at index {0:?} is invalid")]
-    InvalidBindGroup(usize),
+    InvalidBindGroup(u32),
     #[error("Device {0:?} is invalid")]
     InvalidDevice(DeviceId),
     #[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
@@ -249,7 +197,7 @@ impl PrettyError for ComputePassErrorInner {
 pub struct ComputePassError {
     pub scope: PassErrorScope,
     #[source]
-    inner: ComputePassErrorInner,
+    pub(super) inner: ComputePassErrorInner,
 }
 impl PrettyError for ComputePassError {
     fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
@@ -346,7 +294,8 @@ impl Global {
         encoder_id: id::CommandEncoderId,
         pass: &ComputePass,
     ) -> Result<(), ComputePassError> {
-        self.command_encoder_run_compute_pass_impl::<A>(
+        // TODO: This should go directly to `command_encoder_run_compute_pass_impl` by means of storing `ArcComputeCommand` internally.
+        self.command_encoder_run_compute_pass_with_unresolved_commands::<A>(
             encoder_id,
             pass.base.as_ref(),
             pass.timestamp_writes.as_ref(),
@@ -354,18 +303,41 @@ impl Global {
     }
 
     #[doc(hidden)]
-    pub fn command_encoder_run_compute_pass_impl<A: HalApi>(
+    pub fn command_encoder_run_compute_pass_with_unresolved_commands<A: HalApi>(
         &self,
         encoder_id: id::CommandEncoderId,
         base: BasePassRef<ComputeCommand>,
         timestamp_writes: Option<&ComputePassTimestampWrites>,
+    ) -> Result<(), ComputePassError> {
+        let resolved_commands =
+            ComputeCommand::resolve_compute_command_ids(A::hub(self), base.commands)?;
+
+        self.command_encoder_run_compute_pass_impl::<A>(
+            encoder_id,
+            BasePassRef {
+                label: base.label,
+                commands: &resolved_commands,
+                dynamic_offsets: base.dynamic_offsets,
+                string_data: base.string_data,
+                push_constant_data: base.push_constant_data,
+            },
+            timestamp_writes,
+        )
+    }
+
+    fn command_encoder_run_compute_pass_impl<A: HalApi>(
+        &self,
+        encoder_id: id::CommandEncoderId,
+        base: BasePassRef<ArcComputeCommand<A>>,
+        timestamp_writes: Option<&ComputePassTimestampWrites>,
     ) -> Result<(), ComputePassError> {
         profiling::scope!("CommandEncoder::run_compute_pass");
         let pass_scope = PassErrorScope::Pass(encoder_id);
 
         let hub = A::hub(self);
 
-        let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;
+        let cmd_buf: Arc<CommandBuffer<A>> =
+            CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;
         let device = &cmd_buf.device;
         if !device.is_valid() {
             return Err(ComputePassErrorInner::InvalidDevice(
@@ -380,7 +352,13 @@ impl Global {
         #[cfg(feature = "trace")]
         if let Some(ref mut list) = cmd_buf_data.commands {
             list.push(crate::device::trace::Command::RunComputePass {
-                base: BasePass::from_ref(base),
+                base: BasePass {
+                    label: base.label.map(str::to_string),
+                    commands: base.commands.iter().map(Into::into).collect(),
+                    dynamic_offsets: base.dynamic_offsets.to_vec(),
+                    string_data: base.string_data.to_vec(),
+                    push_constant_data: base.push_constant_data.to_vec(),
+                },
                 timestamp_writes: timestamp_writes.cloned(),
             });
         }
@@ -400,9 +378,7 @@ impl Global {
         let raw = encoder.open().map_pass_err(pass_scope)?;
 
         let bind_group_guard = hub.bind_groups.read();
-        let pipeline_guard = hub.compute_pipelines.read();
         let query_set_guard = hub.query_sets.read();
-        let buffer_guard = hub.buffers.read();
 
         let mut state = State {
             binder: Binder::new(),
@@ -480,19 +456,21 @@ impl Global {
         // be inserted before texture reads.
         let mut pending_discard_init_fixups = SurfacesInDiscardState::new();
 
+        // TODO: We should be draining the commands here, avoiding extra copies in the process.
+        //       (A command encoder can't be executed twice!)
         for command in base.commands {
-            match *command {
-                ComputeCommand::SetBindGroup {
+            match command {
+                ArcComputeCommand::SetBindGroup {
                     index,
                     num_dynamic_offsets,
-                    bind_group_id,
+                    bind_group,
                 } => {
-                    let scope = PassErrorScope::SetBindGroup(bind_group_id);
+                    let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());
 
                     let max_bind_groups = cmd_buf.limits.max_bind_groups;
-                    if index >= max_bind_groups {
+                    if index >= &max_bind_groups {
                         return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
-                            index,
+                            index: *index,
                             max: max_bind_groups,
                         })
                         .map_pass_err(scope);
@@ -505,13 +483,9 @@ impl Global {
                     );
                     dynamic_offset_count += num_dynamic_offsets;
 
-                    let bind_group = tracker
-                        .bind_groups
-                        .add_single(&*bind_group_guard, bind_group_id)
-                        .ok_or(ComputePassErrorInner::InvalidBindGroup(index as usize))
-                        .map_pass_err(scope)?;
+                    let bind_group = tracker.bind_groups.insert_single(bind_group.clone());
                     bind_group
-                        .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
+                        .validate_dynamic_bindings(*index, &temp_offsets, &cmd_buf.limits)
                         .map_pass_err(scope)?;
 
                     buffer_memory_init_actions.extend(
@@ -533,14 +507,14 @@ impl Global {
 
                     let entries = state
                         .binder
-                        .assign_group(index as usize, bind_group, &temp_offsets);
+                        .assign_group(*index as usize, bind_group, &temp_offsets);
                     if !entries.is_empty() && pipeline_layout.is_some() {
                         let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
                         for (i, e) in entries.iter().enumerate() {
                             if let Some(group) = e.group.as_ref() {
                                 let raw_bg = group
                                     .raw(&snatch_guard)
-                                    .ok_or(ComputePassErrorInner::InvalidBindGroup(i))
+                                    .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
                                     .map_pass_err(scope)?;
                                 unsafe {
                                     raw.set_bind_group(
@@ -554,16 +528,13 @@ impl Global {
                         }
                     }
                 }
-                ComputeCommand::SetPipeline(pipeline_id) => {
+                ArcComputeCommand::SetPipeline(pipeline) => {
+                    let pipeline_id = pipeline.as_info().id();
                     let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
 
                     state.pipeline = Some(pipeline_id);
 
-                    let pipeline: &pipeline::ComputePipeline<A> = tracker
-                        .compute_pipelines
-                        .add_single(&*pipeline_guard, pipeline_id)
-                        .ok_or(ComputePassErrorInner::InvalidPipeline(pipeline_id))
-                        .map_pass_err(scope)?;
+                    tracker.compute_pipelines.insert_single(pipeline.clone());
 
                     unsafe {
                         raw.set_compute_pipeline(pipeline.raw());
@@ -587,7 +558,7 @@ impl Global {
                             if let Some(group) = e.group.as_ref() {
                                 let raw_bg = group
                                     .raw(&snatch_guard)
-                                    .ok_or(ComputePassErrorInner::InvalidBindGroup(i))
+                                    .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
                                     .map_pass_err(scope)?;
                                 unsafe {
                                     raw.set_bind_group(
@@ -623,7 +594,7 @@ impl Global {
                         }
                     }
                 }
-                ComputeCommand::SetPushConstant {
+                ArcComputeCommand::SetPushConstant {
                     offset,
                     size_bytes,
                     values_offset,
@@ -634,7 +605,7 @@ impl Global {
                     let values_end_offset =
                         (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                     let data_slice =
-                        &base.push_constant_data[(values_offset as usize)..values_end_offset];
+                        &base.push_constant_data[(*values_offset as usize)..values_end_offset];
 
                     let pipeline_layout = state
                         .binder
@@ -649,7 +620,7 @@ impl Global {
                     pipeline_layout
                         .validate_push_constant_ranges(
                             wgt::ShaderStages::COMPUTE,
-                            offset,
+                            *offset,
                             end_offset_bytes,
                         )
                         .map_pass_err(scope)?;
@@ -658,12 +629,12 @@ impl Global {
                         raw.set_push_constants(
                             pipeline_layout.raw(),
                             wgt::ShaderStages::COMPUTE,
-                            offset,
+                            *offset,
                             data_slice,
                         );
                     }
                 }
-                ComputeCommand::Dispatch(groups) => {
+                ArcComputeCommand::Dispatch(groups) => {
                     let scope = PassErrorScope::Dispatch {
                         indirect: false,
                         pipeline: state.pipeline,
@@ -688,7 +659,7 @@ impl Global {
                     {
                         return Err(ComputePassErrorInner::Dispatch(
                             DispatchError::InvalidGroupSize {
-                                current: groups,
+                                current: *groups,
                                 limit: groups_size_limit,
                             },
                         ))
@@ -696,10 +667,11 @@ impl Global {
                     }
 
                     unsafe {
-                        raw.dispatch(groups);
+                        raw.dispatch(*groups);
                    }
                 }
-                ComputeCommand::DispatchIndirect { buffer_id, offset } => {
+                ArcComputeCommand::DispatchIndirect { buffer, offset } => {
+                    let buffer_id = buffer.as_info().id();
                     let scope = PassErrorScope::Dispatch {
                         indirect: true,
                         pipeline: state.pipeline,
@@ -711,29 +683,25 @@ impl Global {
                         .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
                         .map_pass_err(scope)?;
 
-                    let indirect_buffer = state
+                    state
                         .scope
                         .buffers
-                        .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDIRECT)
+                        .insert_merge_single(buffer.clone(), hal::BufferUses::INDIRECT)
+                        .map_pass_err(scope)?;
+                    check_buffer_usage(buffer_id, buffer.usage, wgt::BufferUsages::INDIRECT)
                         .map_pass_err(scope)?;
-                    check_buffer_usage(
-                        buffer_id,
-                        indirect_buffer.usage,
-                        wgt::BufferUsages::INDIRECT,
-                    )
-                    .map_pass_err(scope)?;
 
                     let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
-                    if end_offset > indirect_buffer.size {
+                    if end_offset > buffer.size {
                         return Err(ComputePassErrorInner::IndirectBufferOverrun {
-                            offset,
+                            offset: *offset,
                             end_offset,
-                            buffer_size: indirect_buffer.size,
+                            buffer_size: buffer.size,
                         })
                         .map_pass_err(scope);
                     }
 
-                    let buf_raw = indirect_buffer
+                    let buf_raw = buffer
                         .raw
                         .get(&snatch_guard)
                         .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
@@ -742,9 +710,9 @@ impl Global {
                     let stride = 3 * 4; // 3 integers, x/y/z group size
 
                     buffer_memory_init_actions.extend(
-                        indirect_buffer.initialization_status.read().create_action(
-                            indirect_buffer,
-                            offset..(offset + stride),
+                        buffer.initialization_status.read().create_action(
+                            buffer,
+                            *offset..(*offset + stride),
                             MemoryInitKind::NeedsInitializedMemory,
                         ),
                     );
@@ -754,15 +722,15 @@ impl Global {
                             raw,
                             &mut intermediate_trackers,
                             &*bind_group_guard,
-                            Some(indirect_buffer.as_info().tracker_index()),
+                            Some(buffer.as_info().tracker_index()),
                             &snatch_guard,
                         )
                         .map_pass_err(scope)?;
                     unsafe {
-                        raw.dispatch_indirect(buf_raw, offset);
+                        raw.dispatch_indirect(buf_raw, *offset);
                     }
                 }
-                ComputeCommand::PushDebugGroup { color: _, len } => {
+                ArcComputeCommand::PushDebugGroup { color: _, len } => {
                     state.debug_scope_depth += 1;
                     if !discard_hal_labels {
                         let label =
@@ -774,7 +742,7 @@ impl Global {
                     }
                     string_offset += len;
                 }
-                ComputeCommand::PopDebugGroup => {
+                ArcComputeCommand::PopDebugGroup => {
                     let scope = PassErrorScope::PopDebugGroup;
 
                     if state.debug_scope_depth == 0 {
@@ -788,7 +756,7 @@ impl Global {
                         }
                     }
                 }
-                ComputeCommand::InsertDebugMarker { color: _, len } => {
+                ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                     if !discard_hal_labels {
                         let label =
                             str::from_utf8(&base.string_data[string_offset..string_offset + len])
@@ -797,49 +765,43 @@ impl Global {
                     }
                     string_offset += len;
                 }
-                ComputeCommand::WriteTimestamp {
-                    query_set_id,
+                ArcComputeCommand::WriteTimestamp {
+                    query_set,
                     query_index,
                 } => {
+                    let query_set_id = query_set.as_info().id();
                     let scope = PassErrorScope::WriteTimestamp;
 
                     device
                         .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
                         .map_pass_err(scope)?;
 
-                    let query_set: &resource::QuerySet<A> = tracker
-                        .query_sets
-                        .add_single(&*query_set_guard, query_set_id)
-                        .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
-                        .map_pass_err(scope)?;
+                    let query_set = tracker.query_sets.insert_single(query_set.clone());
 
                     query_set
-                        .validate_and_write_timestamp(raw, query_set_id, query_index, None)
+                        .validate_and_write_timestamp(raw, query_set_id, *query_index, None)
                         .map_pass_err(scope)?;
                 }
-                ComputeCommand::BeginPipelineStatisticsQuery {
-                    query_set_id,
+                ArcComputeCommand::BeginPipelineStatisticsQuery {
+                    query_set,
                     query_index,
                 } => {
+                    let query_set_id = query_set.as_info().id();
                     let scope = PassErrorScope::BeginPipelineStatisticsQuery;
 
-                    let query_set: &resource::QuerySet<A> = tracker
-                        .query_sets
-                        .add_single(&*query_set_guard, query_set_id)
-                        .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
-                        .map_pass_err(scope)?;
+                    let query_set = tracker.query_sets.insert_single(query_set.clone());
 
                     query_set
                         .validate_and_begin_pipeline_statistics_query(
                             raw,
                             query_set_id,
-                            query_index,
+                            *query_index,
                             None,
                             &mut active_query,
                         )
                         .map_pass_err(scope)?;
                 }
-                ComputeCommand::EndPipelineStatisticsQuery => {
+                ArcComputeCommand::EndPipelineStatisticsQuery => {
                     let scope = PassErrorScope::EndPipelineStatisticsQuery;
 
                     end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
@@ -883,33 +845,24 @@ impl Global {
     }
 }
 
-pub mod compute_ffi {
+pub mod compute_commands {
     use super::{ComputeCommand, ComputePass};
-    use crate::{id, RawString};
-    use std::{convert::TryInto, ffi, slice};
+    use crate::id;
+    use std::convert::TryInto;
     use wgt::{BufferAddress, DynamicOffset};
 
-    /// # Safety
-    ///
-    /// This function is unsafe as there is no guarantee that the given pointer is
-    /// valid for `offset_length` elements.
-    #[no_mangle]
-    pub unsafe extern "C" fn wgpu_compute_pass_set_bind_group(
+    pub fn wgpu_compute_pass_set_bind_group(
         pass: &mut ComputePass,
         index: u32,
         bind_group_id: id::BindGroupId,
-        offsets: *const DynamicOffset,
-        offset_length: usize,
+        offsets: &[DynamicOffset],
     ) {
-        let redundant = unsafe {
-            pass.current_bind_groups.set_and_check_redundant(
-                bind_group_id,
-                index,
-                &mut pass.base.dynamic_offsets,
-                offsets,
-                offset_length,
-            )
-        };
+        let redundant = pass.current_bind_groups.set_and_check_redundant(
+            bind_group_id,
+            index,
+            &mut pass.base.dynamic_offsets,
+            offsets,
+        );
 
         if redundant {
             return;
@@ -917,13 +870,12 @@ pub mod compute_ffi {
 
         pass.base.commands.push(ComputeCommand::SetBindGroup {
             index,
-            num_dynamic_offsets: offset_length,
+            num_dynamic_offsets: offsets.len(),
             bind_group_id,
         });
     }
 
-    #[no_mangle]
-    pub extern "C" fn wgpu_compute_pass_set_pipeline(
+    pub fn wgpu_compute_pass_set_pipeline(
         pass: &mut ComputePass,
         pipeline_id: id::ComputePipelineId,
     ) {
@@ -936,47 +888,34 @@ pub mod compute_ffi {
             .push(ComputeCommand::SetPipeline(pipeline_id));
     }
 
-    /// # Safety
-    ///
-    /// This function is unsafe as there is no guarantee that the given pointer is
-    /// valid for `size_bytes` bytes.
-    #[no_mangle]
-    pub unsafe extern "C" fn wgpu_compute_pass_set_push_constant(
-        pass: &mut ComputePass,
-        offset: u32,
-        size_bytes: u32,
-        data: *const u8,
-    ) {
+    pub fn wgpu_compute_pass_set_push_constant(pass: &mut ComputePass, offset: u32, data: &[u8]) {
         assert_eq!(
             offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
             0,
             "Push constant offset must be aligned to 4 bytes."
         );
         assert_eq!(
-            size_bytes & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
+            data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
             0,
             "Push constant size must be aligned to 4 bytes."
         );
-        let data_slice = unsafe { slice::from_raw_parts(data, size_bytes as usize) };
         let value_offset = pass.base.push_constant_data.len().try_into().expect(
             "Ran out of push constant space. Don't set 4gb of push constants per ComputePass.",
         );
 
         pass.base.push_constant_data.extend(
-            data_slice
-                .chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
+            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                 .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
         );
 
         pass.base.commands.push(ComputeCommand::SetPushConstant {
             offset,
-            size_bytes,
+            size_bytes: data.len() as u32,
            values_offset: value_offset,
         });
     }
 
-    #[no_mangle]
-    pub extern "C" fn wgpu_compute_pass_dispatch_workgroups(
+    pub fn wgpu_compute_pass_dispatch_workgroups(
         pass: &mut ComputePass,
         groups_x: u32,
         groups_y: u32,
@@ -987,8 +926,7 @@ pub mod compute_ffi {
             .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
     }
 
-    #[no_mangle]
-    pub extern "C" fn wgpu_compute_pass_dispatch_workgroups_indirect(
+    pub fn wgpu_compute_pass_dispatch_workgroups_indirect(
         pass: &mut ComputePass,
         buffer_id: id::BufferId,
         offset: BufferAddress,
@@ -998,17 +936,8 @@ pub mod compute_ffi {
             .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
     }
 
-    /// # Safety
-    ///
-    /// This function is unsafe as there is no guarantee that the given `label`
-    /// is a valid null-terminated string.
-    #[no_mangle]
-    pub unsafe extern "C" fn wgpu_compute_pass_push_debug_group(
-        pass: &mut ComputePass,
-        label: RawString,
-        color: u32,
-    ) {
-        let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
+    pub fn wgpu_compute_pass_push_debug_group(pass: &mut ComputePass, label: &str, color: u32) {
+        let bytes = label.as_bytes();
         pass.base.string_data.extend_from_slice(bytes);
 
         pass.base.commands.push(ComputeCommand::PushDebugGroup {
@@ -1017,22 +946,12 @@ pub mod compute_ffi {
         });
     }
 
-    #[no_mangle]
-    pub extern "C" fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
+    pub fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
         pass.base.commands.push(ComputeCommand::PopDebugGroup);
     }
 
-    /// # Safety
-    ///
-    /// This function is unsafe as there is no guarantee that the given `label`
-    /// is a valid null-terminated string.
-    #[no_mangle]
-    pub unsafe extern "C" fn wgpu_compute_pass_insert_debug_marker(
-        pass: &mut ComputePass,
-        label: RawString,
-        color: u32,
-    ) {
-        let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
+    pub fn wgpu_compute_pass_insert_debug_marker(pass: &mut ComputePass, label: &str, color: u32) {
+        let bytes = label.as_bytes();
         pass.base.string_data.extend_from_slice(bytes);
 
         pass.base.commands.push(ComputeCommand::InsertDebugMarker {
@@ -1041,8 +960,7 @@ pub mod compute_ffi {
         });
     }
 
-    #[no_mangle]
-    pub extern "C" fn wgpu_compute_pass_write_timestamp(
+    pub fn wgpu_compute_pass_write_timestamp(
         pass: &mut ComputePass,
         query_set_id: id::QuerySetId,
         query_index: u32,
@@ -1053,8 +971,7 @@ pub mod compute_ffi {
         });
     }
 
-    #[no_mangle]
-    pub extern "C" fn wgpu_compute_pass_begin_pipeline_statistics_query(
+    pub fn wgpu_compute_pass_begin_pipeline_statistics_query(
         pass: &mut ComputePass,
         query_set_id: id::QuerySetId,
         query_index: u32,
@@ -1067,8 +984,7 @@ pub mod compute_ffi {
         });
     }
 
-    #[no_mangle]
-    pub extern "C" fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
+    pub fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
         pass.base
             .commands
             .push(ComputeCommand::EndPipelineStatisticsQuery);
--
cgit v1.2.3
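
The caller-facing half of this patch is the replacement of the unsafe `compute_ffi` C ABI with the safe `compute_commands` module: slices replace pointer/length pairs, `&str` replaces null-terminated strings, and every `unsafe extern "C"` entry point becomes a plain `pub fn`. The sketch below shows what recording looks like against the new functions. It is illustrative only: the `wgpu_core::command` import paths and the way `pass` and the ids are obtained are assumptions, while the function names and signatures are exactly those introduced in the diff above.

// Sketch: recording a compute pass through the safe `compute_commands` API.
// The import paths are assumed; ids would come from the usual resource
// creation entry points on `Global`.
use wgpu_core::command::compute_commands::*;
use wgpu_core::command::ComputePass;
use wgpu_core::id;

fn record_pass(
    pass: &mut ComputePass,
    pipeline_id: id::ComputePipelineId,
    bind_group_id: id::BindGroupId,
) {
    wgpu_compute_pass_set_pipeline(pass, pipeline_id);
    // Dynamic offsets are an ordinary slice now; empty means none.
    wgpu_compute_pass_set_bind_group(pass, 0, bind_group_id, &[]);
    // Push-constant data is a byte slice whose length must be a multiple of
    // wgt::PUSH_CONSTANT_ALIGNMENT (4 bytes), enforced by the assert_eq! above.
    wgpu_compute_pass_set_push_constant(pass, 0, &42u32.to_ne_bytes());
    // Debug labels are plain &str instead of null-terminated C strings.
    wgpu_compute_pass_push_debug_group(pass, "main dispatch", 0);
    wgpu_compute_pass_dispatch_workgroups(pass, 64, 1, 1);
    wgpu_compute_pass_pop_debug_group(pass);
}

The recorded commands still carry ids; `command_encoder_run_compute_pass_with_unresolved_commands` resolves them to `Arc`'d resources once via `ComputeCommand::resolve_compute_command_ids`, so the `ArcComputeCommand` execution loop no longer performs a hub lookup per command.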