Diffstat (limited to 'third_party/rust/wgpu-core/src/command/compute_command.rs')
-rw-r--r--  third_party/rust/wgpu-core/src/command/compute_command.rs  322
1 file changed, 322 insertions, 0 deletions
diff --git a/third_party/rust/wgpu-core/src/command/compute_command.rs b/third_party/rust/wgpu-core/src/command/compute_command.rs
new file mode 100644
index 0000000000..49fdbbec24
--- /dev/null
+++ b/third_party/rust/wgpu-core/src/command/compute_command.rs
@@ -0,0 +1,322 @@
+use std::sync::Arc;
+
+use crate::{
+    binding_model::BindGroup,
+    hal_api::HalApi,
+    id,
+    pipeline::ComputePipeline,
+    resource::{Buffer, QuerySet},
+};
+
+use super::{ComputePassError, ComputePassErrorInner, PassErrorScope};
+
+#[derive(Clone, Copy, Debug)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub enum ComputeCommand {
+    SetBindGroup {
+        index: u32,
+        num_dynamic_offsets: usize,
+        bind_group_id: id::BindGroupId,
+    },
+
+    SetPipeline(id::ComputePipelineId),
+
+    /// Set a range of push constants to values stored in `push_constant_data`.
+    SetPushConstant {
+        /// The byte offset within the push constant storage to write to. This
+        /// must be a multiple of four.
+        offset: u32,
+
+        /// The number of bytes to write. This must be a multiple of four.
+        size_bytes: u32,
+
+        /// Index in `push_constant_data` of the start of the data
+        /// to be written.
+        ///
+        /// Note: this is not a byte offset like `offset`. Rather, it is the
+        /// index of the first `u32` element in `push_constant_data` to read.
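+        ///
+        /// For example (illustrative values, not from the source): with
+        /// `offset == 4`, `size_bytes == 8`, and `values_offset == 3`, the
+        /// pass reads the two `u32`s at `push_constant_data[3..5]` and
+        /// writes them to push constant bytes `4..12`.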
+        values_offset: u32,
+    },
+
+    Dispatch([u32; 3]),
+
+    DispatchIndirect {
+        buffer_id: id::BufferId,
+        offset: wgt::BufferAddress,
+    },
+
+    PushDebugGroup {
+        color: u32,
+        len: usize,
+    },
+
+    PopDebugGroup,
+
+    InsertDebugMarker {
+        color: u32,
+        len: usize,
+    },
+
+    WriteTimestamp {
+        query_set_id: id::QuerySetId,
+        query_index: u32,
+    },
+
+    BeginPipelineStatisticsQuery {
+        query_set_id: id::QuerySetId,
+        query_index: u32,
+    },
+
+    EndPipelineStatisticsQuery,
+}
+
+impl ComputeCommand {
+    /// Resolves all ids in a list of commands into the corresponding resource `Arc`s.
+    ///
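+    /// A hypothetical call site (sketch only; obtaining a `hub` with valid
+    /// registries is elided here):
+    ///
+    /// ```ignore
+    /// // `Dispatch` carries no ids, so it resolves unconditionally.
+    /// let commands = vec![ComputeCommand::Dispatch([64, 1, 1])];
+    /// let resolved = ComputeCommand::resolve_compute_command_ids(&hub, &commands)?;
+    /// assert_eq!(resolved.len(), commands.len());
+    /// ```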
+    // TODO: Once resolving is done on-the-fly during recording, this function
+    // should only be needed with the replay feature:
+    // #[cfg(feature = "replay")]
+    pub fn resolve_compute_command_ids<A: HalApi>(
+        hub: &crate::hub::Hub<A>,
+        commands: &[ComputeCommand],
+    ) -> Result<Vec<ArcComputeCommand<A>>, ComputePassError> {
+        // Acquire read guards up front for every registry the ids below may
+        // need to be looked up in.
+        let buffers_guard = hub.buffers.read();
+        let bind_group_guard = hub.bind_groups.read();
+        let query_set_guard = hub.query_sets.read();
+        let pipelines_guard = hub.compute_pipelines.read();
+
+        let resolved_commands: Vec<ArcComputeCommand<A>> = commands
+            .iter()
+            .map(|c| -> Result<ArcComputeCommand<A>, ComputePassError> {
+                Ok(match *c {
+                    ComputeCommand::SetBindGroup {
+                        index,
+                        num_dynamic_offsets,
+                        bind_group_id,
+                    } => ArcComputeCommand::SetBindGroup {
+                        index,
+                        num_dynamic_offsets,
+                        bind_group: bind_group_guard.get_owned(bind_group_id).map_err(|_| {
+                            ComputePassError {
+                                scope: PassErrorScope::SetBindGroup(bind_group_id),
+                                inner: ComputePassErrorInner::InvalidBindGroup(index),
+                            }
+                        })?,
+                    },
+
+                    ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
+                        pipelines_guard
+                            .get_owned(pipeline_id)
+                            .map_err(|_| ComputePassError {
+                                scope: PassErrorScope::SetPipelineCompute(pipeline_id),
+                                inner: ComputePassErrorInner::InvalidPipeline(pipeline_id),
+                            })?,
+                    ),
+
+                    ComputeCommand::SetPushConstant {
+                        offset,
+                        size_bytes,
+                        values_offset,
+                    } => ArcComputeCommand::SetPushConstant {
+                        offset,
+                        size_bytes,
+                        values_offset,
+                    },
+
+                    ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
+
+                    ComputeCommand::DispatchIndirect { buffer_id, offset } => {
+                        ArcComputeCommand::DispatchIndirect {
+                            buffer: buffers_guard.get_owned(buffer_id).map_err(|_| {
+                                ComputePassError {
+                                    scope: PassErrorScope::Dispatch {
+                                        indirect: true,
+                                        pipeline: None, // TODO: not used right now, but once we do the resolve during recording we can use this again.
+                                    },
+                                    inner: ComputePassErrorInner::InvalidBuffer(buffer_id),
+                                }
+                            })?,
+                            offset,
+                        }
+                    }
+
+                    ComputeCommand::PushDebugGroup { color, len } => {
+                        ArcComputeCommand::PushDebugGroup { color, len }
+                    }
+
+                    ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
+
+                    ComputeCommand::InsertDebugMarker { color, len } => {
+                        ArcComputeCommand::InsertDebugMarker { color, len }
+                    }
+
+                    ComputeCommand::WriteTimestamp {
+                        query_set_id,
+                        query_index,
+                    } => ArcComputeCommand::WriteTimestamp {
+                        query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
+                            ComputePassError {
+                                scope: PassErrorScope::WriteTimestamp,
+                                inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
+                            }
+                        })?,
+                        query_index,
+                    },
+
+                    ComputeCommand::BeginPipelineStatisticsQuery {
+                        query_set_id,
+                        query_index,
+                    } => ArcComputeCommand::BeginPipelineStatisticsQuery {
+                        query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
+                            ComputePassError {
+                                scope: PassErrorScope::BeginPipelineStatisticsQuery,
+                                inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
+                            }
+                        })?,
+                        query_index,
+                    },
+
+                    ComputeCommand::EndPipelineStatisticsQuery => {
+                        ArcComputeCommand::EndPipelineStatisticsQuery
+                    }
+                })
+            })
+            .collect::<Result<Vec<_>, ComputePassError>>()?;
+        Ok(resolved_commands)
+    }
+}
+
+/// Equivalent to `ComputeCommand`, but with the ids resolved into resource `Arc`s.
+#[derive(Clone, Debug)]
+pub enum ArcComputeCommand<A: HalApi> {
+    SetBindGroup {
+        index: u32,
+        num_dynamic_offsets: usize,
+        bind_group: Arc<BindGroup<A>>,
+    },
+
+    SetPipeline(Arc<ComputePipeline<A>>),
+
+    /// Set a range of push constants to values stored in `push_constant_data`.
+    SetPushConstant {
+        /// The byte offset within the push constant storage to write to. This
+        /// must be a multiple of four.
+        offset: u32,
+
+        /// The number of bytes to write. This must be a multiple of four.
+        size_bytes: u32,
+
+        /// Index in `push_constant_data` of the start of the data
+        /// to be written.
+        ///
+        /// Note: this is not a byte offset like `offset`. Rather, it is the
+        /// index of the first `u32` element in `push_constant_data` to read.
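+        ///
+        /// (See the worked example on `ComputeCommand::SetPushConstant`.)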
+        values_offset: u32,
+    },
+
+    Dispatch([u32; 3]),
+
+    DispatchIndirect {
+        buffer: Arc<Buffer<A>>,
+        offset: wgt::BufferAddress,
+    },
+
+    PushDebugGroup {
+        color: u32,
+        len: usize,
+    },
+
+    PopDebugGroup,
+
+    InsertDebugMarker {
+        color: u32,
+        len: usize,
+    },
+
+    WriteTimestamp {
+        query_set: Arc<QuerySet<A>>,
+        query_index: u32,
+    },
+
+    BeginPipelineStatisticsQuery {
+        query_set: Arc<QuerySet<A>>,
+        query_index: u32,
+    },
+
+    EndPipelineStatisticsQuery,
+}
+
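+// Round-trip sketch (illustrative only, not part of the original source): with
+// the `trace` feature enabled, the `From` impl below converts a resolved
+// command back into its id-based form by reading each resource's id, e.g.:
+//
+//     let cmd: ComputeCommand = (&ArcComputeCommand::<A>::Dispatch([8, 8, 1])).into();
+//     assert!(matches!(cmd, ComputeCommand::Dispatch([8, 8, 1])));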
+#[cfg(feature = "trace")]
+impl<A: HalApi> From<&ArcComputeCommand<A>> for ComputeCommand {
+    fn from(value: &ArcComputeCommand<A>) -> Self {
+        use crate::resource::Resource as _;
+
+        match value {
+            ArcComputeCommand::SetBindGroup {
+                index,
+                num_dynamic_offsets,
+                bind_group,
+            } => ComputeCommand::SetBindGroup {
+                index: *index,
+                num_dynamic_offsets: *num_dynamic_offsets,
+                bind_group_id: bind_group.as_info().id(),
+            },
+
+            ArcComputeCommand::SetPipeline(pipeline) => {
+                ComputeCommand::SetPipeline(pipeline.as_info().id())
+            }
+
+            ArcComputeCommand::SetPushConstant {
+                offset,
+                size_bytes,
+                values_offset,
+            } => ComputeCommand::SetPushConstant {
+                offset: *offset,
+                size_bytes: *size_bytes,
+                values_offset: *values_offset,
+            },
+
+            ArcComputeCommand::Dispatch(dim) => ComputeCommand::Dispatch(*dim),
+
+            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
+                ComputeCommand::DispatchIndirect {
+                    buffer_id: buffer.as_info().id(),
+                    offset: *offset,
+                }
+            }
+
+            ArcComputeCommand::PushDebugGroup { color, len } => ComputeCommand::PushDebugGroup {
+                color: *color,
+                len: *len,
+            },
+
+            ArcComputeCommand::PopDebugGroup => ComputeCommand::PopDebugGroup,
+
+            ArcComputeCommand::InsertDebugMarker { color, len } => {
+                ComputeCommand::InsertDebugMarker {
+                    color: *color,
+                    len: *len,
+                }
+            }
+
+            ArcComputeCommand::WriteTimestamp {
+                query_set,
+                query_index,
+            } => ComputeCommand::WriteTimestamp {
+                query_set_id: query_set.as_info().id(),
+                query_index: *query_index,
+            },
+
+            ArcComputeCommand::BeginPipelineStatisticsQuery {
+                query_set,
+                query_index,
+            } => ComputeCommand::BeginPipelineStatisticsQuery {
+                query_set_id: query_set.as_info().id(),
+                query_index: *query_index,
+            },
+
+            ArcComputeCommand::EndPipelineStatisticsQuery => {
+                ComputeCommand::EndPipelineStatisticsQuery
+            }
+        }
+    }
+}