Diffstat (limited to 'third_party/rust/wgpu-hal/src/metal/command.rs')
-rw-r--r--   third_party/rust/wgpu-hal/src/metal/command.rs   972
1 file changed, 972 insertions, 0 deletions
diff --git a/third_party/rust/wgpu-hal/src/metal/command.rs b/third_party/rust/wgpu-hal/src/metal/command.rs
new file mode 100644
index 0000000000..3b94c71ed0
--- /dev/null
+++ b/third_party/rust/wgpu-hal/src/metal/command.rs
@@ -0,0 +1,972 @@
use super::{conv, AsNative};
use std::{mem, ops::Range};

// has to match `Temp::binding_sizes`
const WORD_SIZE: usize = 4;

impl Default for super::CommandState {
    fn default() -> Self {
        Self {
            blit: None,
            render: None,
            compute: None,
            raw_primitive_type: mtl::MTLPrimitiveType::Point,
            index: None,
            raw_wg_size: mtl::MTLSize::new(0, 0, 0),
            stage_infos: Default::default(),
            storage_buffer_length_map: Default::default(),
            work_group_memory_sizes: Vec::new(),
            push_constants: Vec::new(),
        }
    }
}

impl super::CommandEncoder {
    fn enter_blit(&mut self) -> &mtl::BlitCommandEncoderRef {
        if self.state.blit.is_none() {
            debug_assert!(self.state.render.is_none() && self.state.compute.is_none());
            objc::rc::autoreleasepool(|| {
                let cmd_buf = self.raw_cmd_buf.as_ref().unwrap();
                self.state.blit = Some(cmd_buf.new_blit_command_encoder().to_owned());
            });
        }
        self.state.blit.as_ref().unwrap()
    }

    pub(super) fn leave_blit(&mut self) {
        if let Some(encoder) = self.state.blit.take() {
            encoder.end_encoding();
        }
    }

    fn enter_any(&mut self) -> Option<&mtl::CommandEncoderRef> {
        if let Some(ref encoder) = self.state.render {
            Some(encoder)
        } else if let Some(ref encoder) = self.state.compute {
            Some(encoder)
        } else if let Some(ref encoder) = self.state.blit {
            Some(encoder)
        } else {
            None
        }
    }

    fn begin_pass(&mut self) {
        self.state.reset();
        self.leave_blit();
    }
}

impl super::CommandState {
    fn reset(&mut self) {
        self.storage_buffer_length_map.clear();
        self.stage_infos.vs.clear();
        self.stage_infos.fs.clear();
        self.stage_infos.cs.clear();
        self.work_group_memory_sizes.clear();
        self.push_constants.clear();
    }

    fn make_sizes_buffer_update<'a>(
        &self,
        stage: naga::ShaderStage,
        result_sizes: &'a mut Vec<u32>,
    ) -> Option<(u32, &'a [u32])> {
        let stage_info = &self.stage_infos[stage];
        let slot = stage_info.sizes_slot?;

        result_sizes.clear();
        result_sizes.extend(stage_info.sized_bindings.iter().map(|br| {
            self.storage_buffer_length_map
                .get(br)
                .map(|size| u32::try_from(size.get()).unwrap_or(u32::MAX))
                .unwrap_or_default()
        }));

        if !result_sizes.is_empty() {
            Some((slot as _, result_sizes))
        } else {
            None
        }
    }
}
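// Note: `make_sizes_buffer_update` exists because MSL gives a shader no
// way to query the length of a buffer binding (unlike SPIR-V's
// `OpArrayLength`), so the byte sizes of runtime-sized storage buffers
// are uploaded as a plain `u32` array bound at `sizes_slot`. The
// clamping above is equivalent to this hypothetical helper (a sketch,
// assuming the recorded size is a `NonZeroU64`-style value):
//
//     fn clamped_len(size: std::num::NonZeroU64) -> u32 {
//         u32::try_from(size.get()).unwrap_or(u32::MAX)
//     }
//
// Bindings with no recorded length fall back to 0 via `unwrap_or_default`.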
impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
    unsafe fn begin_encoding(&mut self, label: crate::Label) -> Result<(), crate::DeviceError> {
        let queue = &self.raw_queue.lock();
        let retain_references = self.shared.settings.retain_command_buffer_references;
        let raw = objc::rc::autoreleasepool(move || {
            let cmd_buf_ref = if retain_references {
                queue.new_command_buffer()
            } else {
                queue.new_command_buffer_with_unretained_references()
            };
            if let Some(label) = label {
                cmd_buf_ref.set_label(label);
            }
            cmd_buf_ref.to_owned()
        });

        self.raw_cmd_buf = Some(raw);

        Ok(())
    }

    unsafe fn discard_encoding(&mut self) {
        self.leave_blit();
        // when discarding, we don't have a guarantee that
        // everything is in a good state, so check carefully
        if let Some(encoder) = self.state.render.take() {
            encoder.end_encoding();
        }
        if let Some(encoder) = self.state.compute.take() {
            encoder.end_encoding();
        }
        self.raw_cmd_buf = None;
    }

    unsafe fn end_encoding(&mut self) -> Result<super::CommandBuffer, crate::DeviceError> {
        self.leave_blit();
        assert!(self.state.render.is_none());
        assert!(self.state.compute.is_none());
        Ok(super::CommandBuffer {
            raw: self.raw_cmd_buf.take().unwrap(),
        })
    }

    unsafe fn reset_all<I>(&mut self, _cmd_bufs: I)
    where
        I: Iterator<Item = super::CommandBuffer>,
    {
        //do nothing
    }

    unsafe fn transition_buffers<'a, T>(&mut self, _barriers: T)
    where
        T: Iterator<Item = crate::BufferBarrier<'a, super::Api>>,
    {
    }

    unsafe fn transition_textures<'a, T>(&mut self, _barriers: T)
    where
        T: Iterator<Item = crate::TextureBarrier<'a, super::Api>>,
    {
    }

    unsafe fn clear_buffer(&mut self, buffer: &super::Buffer, range: crate::MemoryRange) {
        let encoder = self.enter_blit();
        encoder.fill_buffer(&buffer.raw, conv::map_range(&range), 0);
    }

    unsafe fn copy_buffer_to_buffer<T>(
        &mut self,
        src: &super::Buffer,
        dst: &super::Buffer,
        regions: T,
    ) where
        T: Iterator<Item = crate::BufferCopy>,
    {
        let encoder = self.enter_blit();
        for copy in regions {
            encoder.copy_from_buffer(
                &src.raw,
                copy.src_offset,
                &dst.raw,
                copy.dst_offset,
                copy.size.get(),
            );
        }
    }

    unsafe fn copy_texture_to_texture<T>(
        &mut self,
        src: &super::Texture,
        _src_usage: crate::TextureUses,
        dst: &super::Texture,
        regions: T,
    ) where
        T: Iterator<Item = crate::TextureCopy>,
    {
        let encoder = self.enter_blit();
        for copy in regions {
            let src_origin = conv::map_origin(&copy.src_base.origin);
            let dst_origin = conv::map_origin(&copy.dst_base.origin);
            // no clamping is done: Metal expects physical sizes here
            let extent = conv::map_copy_extent(&copy.size);
            encoder.copy_from_texture(
                &src.raw,
                copy.src_base.array_layer as u64,
                copy.src_base.mip_level as u64,
                src_origin,
                extent,
                &dst.raw,
                copy.dst_base.array_layer as u64,
                copy.dst_base.mip_level as u64,
                dst_origin,
            );
        }
    }
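    // Note: texture-to-texture copies take *physical* extents (rounded up
    // to the compressed block size), while the buffer<->texture copies
    // below clamp to *virtual* sizes via `max_copy_size(..).min(..)`. A
    // sketch of the rounding involved, assuming a 4x4 block format:
    //
    //     fn physical_dim(virtual_dim: u32, block_dim: u32) -> u32 {
    //         ((virtual_dim + block_dim - 1) / block_dim) * block_dim // 100 -> 104
    //     }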
    unsafe fn copy_buffer_to_texture<T>(
        &mut self,
        src: &super::Buffer,
        dst: &super::Texture,
        regions: T,
    ) where
        T: Iterator<Item = crate::BufferTextureCopy>,
    {
        let encoder = self.enter_blit();
        for copy in regions {
            let dst_origin = conv::map_origin(&copy.texture_base.origin);
            // Metal expects buffer-texture copies in virtual sizes
            let extent = copy
                .texture_base
                .max_copy_size(&dst.copy_size)
                .min(&copy.size);
            let bytes_per_row = copy
                .buffer_layout
                .bytes_per_row
                .map_or(0, |v| v.get() as u64);
            let image_byte_stride = if extent.depth > 1 {
                copy.buffer_layout
                    .rows_per_image
                    .map_or(0, |v| v.get() as u64 * bytes_per_row)
            } else {
                // Don't pass a stride when updating a single layer, otherwise metal validation
                // fails when updating a subset of the image due to the stride being larger than
                // the amount of data to copy.
                0
            };
            encoder.copy_from_buffer_to_texture(
                &src.raw,
                copy.buffer_layout.offset,
                bytes_per_row,
                image_byte_stride,
                conv::map_copy_extent(&extent),
                &dst.raw,
                copy.texture_base.array_layer as u64,
                copy.texture_base.mip_level as u64,
                dst_origin,
                mtl::MTLBlitOption::empty(),
            );
        }
    }

    unsafe fn copy_texture_to_buffer<T>(
        &mut self,
        src: &super::Texture,
        _src_usage: crate::TextureUses,
        dst: &super::Buffer,
        regions: T,
    ) where
        T: Iterator<Item = crate::BufferTextureCopy>,
    {
        let encoder = self.enter_blit();
        for copy in regions {
            let src_origin = conv::map_origin(&copy.texture_base.origin);
            // Metal expects texture-buffer copies in virtual sizes
            let extent = copy
                .texture_base
                .max_copy_size(&src.copy_size)
                .min(&copy.size);
            let bytes_per_row = copy
                .buffer_layout
                .bytes_per_row
                .map_or(0, |v| v.get() as u64);
            let bytes_per_image = copy
                .buffer_layout
                .rows_per_image
                .map_or(0, |v| v.get() as u64 * bytes_per_row);
            encoder.copy_from_texture_to_buffer(
                &src.raw,
                copy.texture_base.array_layer as u64,
                copy.texture_base.mip_level as u64,
                src_origin,
                conv::map_copy_extent(&extent),
                &dst.raw,
                copy.buffer_layout.offset,
                bytes_per_row,
                bytes_per_image,
                mtl::MTLBlitOption::empty(),
            );
        }
    }

    unsafe fn begin_query(&mut self, set: &super::QuerySet, index: u32) {
        match set.ty {
            wgt::QueryType::Occlusion => {
                self.state
                    .render
                    .as_ref()
                    .unwrap()
                    .set_visibility_result_mode(
                        mtl::MTLVisibilityResultMode::Boolean,
                        index as u64 * crate::QUERY_SIZE,
                    );
            }
            _ => {}
        }
    }

    unsafe fn end_query(&mut self, set: &super::QuerySet, _index: u32) {
        match set.ty {
            wgt::QueryType::Occlusion => {
                self.state
                    .render
                    .as_ref()
                    .unwrap()
                    .set_visibility_result_mode(mtl::MTLVisibilityResultMode::Disabled, 0);
            }
            _ => {}
        }
    }

    unsafe fn write_timestamp(&mut self, _set: &super::QuerySet, _index: u32) {}

    unsafe fn reset_queries(&mut self, set: &super::QuerySet, range: Range<u32>) {
        let encoder = self.enter_blit();
        let raw_range = mtl::NSRange {
            location: range.start as u64 * crate::QUERY_SIZE,
            length: (range.end - range.start) as u64 * crate::QUERY_SIZE,
        };
        encoder.fill_buffer(&set.raw_buffer, raw_range, 0);
    }

    unsafe fn copy_query_results(
        &mut self,
        set: &super::QuerySet,
        range: Range<u32>,
        buffer: &super::Buffer,
        offset: wgt::BufferAddress,
        _: wgt::BufferSize, // Metal doesn't support queries larger than a single element
    ) {
        let encoder = self.enter_blit();
        let size = (range.end - range.start) as u64 * crate::QUERY_SIZE;
        encoder.copy_from_buffer(
            &set.raw_buffer,
            range.start as u64 * crate::QUERY_SIZE,
            &buffer.raw,
            offset,
            size,
        );
    }
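    // Note: each query occupies a fixed `crate::QUERY_SIZE`-byte slot in
    // the query set's backing buffer, which is why `begin_query`,
    // `reset_queries`, and `copy_query_results` all use the same stride.
    // A sketch of the arithmetic for queries `start..end`:
    //
    //     let location = range.start as u64 * crate::QUERY_SIZE;
    //     let length = (range.end - range.start) as u64 * crate::QUERY_SIZE;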
    // render

    unsafe fn begin_render_pass(&mut self, desc: &crate::RenderPassDescriptor<super::Api>) {
        self.begin_pass();
        self.state.index = None;

        objc::rc::autoreleasepool(|| {
            let descriptor = mtl::RenderPassDescriptor::new();
            //TODO: set visibility results buffer

            for (i, at) in desc.color_attachments.iter().enumerate() {
                if let Some(at) = at.as_ref() {
                    let at_descriptor =
                        descriptor.color_attachments().object_at(i as u64).unwrap();
                    at_descriptor.set_texture(Some(&at.target.view.raw));
                    if let Some(ref resolve) = at.resolve_target {
                        //Note: the selection of levels and slices is already handled by `TextureView`
                        at_descriptor.set_resolve_texture(Some(&resolve.view.raw));
                    }
                    let load_action = if at.ops.contains(crate::AttachmentOps::LOAD) {
                        mtl::MTLLoadAction::Load
                    } else {
                        at_descriptor.set_clear_color(conv::map_clear_color(&at.clear_value));
                        mtl::MTLLoadAction::Clear
                    };
                    let store_action = conv::map_store_action(
                        at.ops.contains(crate::AttachmentOps::STORE),
                        at.resolve_target.is_some(),
                    );
                    at_descriptor.set_load_action(load_action);
                    at_descriptor.set_store_action(store_action);
                }
            }

            if let Some(ref at) = desc.depth_stencil_attachment {
                if at.target.view.aspects.contains(crate::FormatAspects::DEPTH) {
                    let at_descriptor = descriptor.depth_attachment().unwrap();
                    at_descriptor.set_texture(Some(&at.target.view.raw));

                    let load_action = if at.depth_ops.contains(crate::AttachmentOps::LOAD) {
                        mtl::MTLLoadAction::Load
                    } else {
                        at_descriptor.set_clear_depth(at.clear_value.0 as f64);
                        mtl::MTLLoadAction::Clear
                    };
                    let store_action = if at.depth_ops.contains(crate::AttachmentOps::STORE) {
                        mtl::MTLStoreAction::Store
                    } else {
                        mtl::MTLStoreAction::DontCare
                    };
                    at_descriptor.set_load_action(load_action);
                    at_descriptor.set_store_action(store_action);
                }
                if at
                    .target
                    .view
                    .aspects
                    .contains(crate::FormatAspects::STENCIL)
                {
                    let at_descriptor = descriptor.stencil_attachment().unwrap();
                    at_descriptor.set_texture(Some(&at.target.view.raw));

                    let load_action = if at.stencil_ops.contains(crate::AttachmentOps::LOAD) {
                        mtl::MTLLoadAction::Load
                    } else {
                        at_descriptor.set_clear_stencil(at.clear_value.1);
                        mtl::MTLLoadAction::Clear
                    };
                    let store_action = if at.stencil_ops.contains(crate::AttachmentOps::STORE) {
                        mtl::MTLStoreAction::Store
                    } else {
                        mtl::MTLStoreAction::DontCare
                    };
                    at_descriptor.set_load_action(load_action);
                    at_descriptor.set_store_action(store_action);
                }
            }

            let raw = self.raw_cmd_buf.as_ref().unwrap();
            let encoder = raw.new_render_command_encoder(descriptor);
            if let Some(label) = desc.label {
                encoder.set_label(label);
            }
            self.state.render = Some(encoder.to_owned());
        });
    }

    unsafe fn end_render_pass(&mut self) {
        self.state.render.take().unwrap().end_encoding();
    }
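    // Note: `conv::map_store_action` folds the resolve flag into the
    // store action; a plausible sketch of that mapping (the real table
    // lives in `conv`):
    //
    //     fn map_store_action(store: bool, resolve: bool) -> mtl::MTLStoreAction {
    //         match (store, resolve) {
    //             (true, true) => mtl::MTLStoreAction::StoreAndMultisampleResolve,
    //             (false, true) => mtl::MTLStoreAction::MultisampleResolve,
    //             (true, false) => mtl::MTLStoreAction::Store,
    //             (false, false) => mtl::MTLStoreAction::DontCare,
    //         }
    //     }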
    unsafe fn set_bind_group(
        &mut self,
        layout: &super::PipelineLayout,
        group_index: u32,
        group: &super::BindGroup,
        dynamic_offsets: &[wgt::DynamicOffset],
    ) {
        let bg_info = &layout.bind_group_infos[group_index as usize];

        if let Some(ref encoder) = self.state.render {
            let mut changes_sizes_buffer = false;
            for index in 0..group.counters.vs.buffers {
                let buf = &group.buffers[index as usize];
                let mut offset = buf.offset;
                if let Some(dyn_index) = buf.dynamic_index {
                    offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
                }
                encoder.set_vertex_buffer(
                    (bg_info.base_resource_indices.vs.buffers + index) as u64,
                    Some(buf.ptr.as_native()),
                    offset,
                );
                if let Some(size) = buf.binding_size {
                    let br = naga::ResourceBinding {
                        group: group_index,
                        binding: buf.binding_location,
                    };
                    self.state.storage_buffer_length_map.insert(br, size);
                    changes_sizes_buffer = true;
                }
            }
            if changes_sizes_buffer {
                if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
                    naga::ShaderStage::Vertex,
                    &mut self.temp.binding_sizes,
                ) {
                    encoder.set_vertex_bytes(
                        index as _,
                        (sizes.len() * WORD_SIZE) as u64,
                        sizes.as_ptr() as _,
                    );
                }
            }

            changes_sizes_buffer = false;
            for index in 0..group.counters.fs.buffers {
                let buf = &group.buffers[(group.counters.vs.buffers + index) as usize];
                let mut offset = buf.offset;
                if let Some(dyn_index) = buf.dynamic_index {
                    offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
                }
                encoder.set_fragment_buffer(
                    (bg_info.base_resource_indices.fs.buffers + index) as u64,
                    Some(buf.ptr.as_native()),
                    offset,
                );
                if let Some(size) = buf.binding_size {
                    let br = naga::ResourceBinding {
                        group: group_index,
                        binding: buf.binding_location,
                    };
                    self.state.storage_buffer_length_map.insert(br, size);
                    changes_sizes_buffer = true;
                }
            }
            if changes_sizes_buffer {
                if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
                    naga::ShaderStage::Fragment,
                    &mut self.temp.binding_sizes,
                ) {
                    encoder.set_fragment_bytes(
                        index as _,
                        (sizes.len() * WORD_SIZE) as u64,
                        sizes.as_ptr() as _,
                    );
                }
            }

            for index in 0..group.counters.vs.samplers {
                let res = group.samplers[index as usize];
                encoder.set_vertex_sampler_state(
                    (bg_info.base_resource_indices.vs.samplers + index) as u64,
                    Some(res.as_native()),
                );
            }
            for index in 0..group.counters.fs.samplers {
                let res = group.samplers[(group.counters.vs.samplers + index) as usize];
                encoder.set_fragment_sampler_state(
                    (bg_info.base_resource_indices.fs.samplers + index) as u64,
                    Some(res.as_native()),
                );
            }

            for index in 0..group.counters.vs.textures {
                let res = group.textures[index as usize];
                encoder.set_vertex_texture(
                    (bg_info.base_resource_indices.vs.textures + index) as u64,
                    Some(res.as_native()),
                );
            }
            for index in 0..group.counters.fs.textures {
                let res = group.textures[(group.counters.vs.textures + index) as usize];
                encoder.set_fragment_texture(
                    (bg_info.base_resource_indices.fs.textures + index) as u64,
                    Some(res.as_native()),
                );
            }
        }

        if let Some(ref encoder) = self.state.compute {
            let index_base = super::ResourceData {
                buffers: group.counters.vs.buffers + group.counters.fs.buffers,
                samplers: group.counters.vs.samplers + group.counters.fs.samplers,
                textures: group.counters.vs.textures + group.counters.fs.textures,
            };

            let mut changes_sizes_buffer = false;
            for index in 0..group.counters.cs.buffers {
                let buf = &group.buffers[(index_base.buffers + index) as usize];
                let mut offset = buf.offset;
                if let Some(dyn_index) = buf.dynamic_index {
                    offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress;
                }
                encoder.set_buffer(
                    (bg_info.base_resource_indices.cs.buffers + index) as u64,
                    Some(buf.ptr.as_native()),
                    offset,
                );
                if let Some(size) = buf.binding_size {
                    let br = naga::ResourceBinding {
                        group: group_index,
                        binding: buf.binding_location,
                    };
                    self.state.storage_buffer_length_map.insert(br, size);
                    changes_sizes_buffer = true;
                }
            }
            if changes_sizes_buffer {
                if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
                    naga::ShaderStage::Compute,
                    &mut self.temp.binding_sizes,
                ) {
                    encoder.set_bytes(
                        index as _,
                        (sizes.len() * WORD_SIZE) as u64,
                        sizes.as_ptr() as _,
                    );
                }
            }

            for index in 0..group.counters.cs.samplers {
                let res = group.samplers[(index_base.samplers + index) as usize];
                encoder.set_sampler_state(
                    (bg_info.base_resource_indices.cs.samplers + index) as u64,
                    Some(res.as_native()),
                );
            }
            for index in 0..group.counters.cs.textures {
                let res = group.textures[(index_base.textures + index) as usize];
                encoder.set_texture(
                    (bg_info.base_resource_indices.cs.textures + index) as u64,
                    Some(res.as_native()),
                );
            }
        }
    }
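    // Note: dynamic offsets are resolved at bind time; each dynamic
    // buffer records the index of its entry in the flat
    // `dynamic_offsets` slice. The per-binding computation above boils
    // down to this hypothetical helper:
    //
    //     fn resolved_offset(
    //         base: wgt::BufferAddress,
    //         dynamic_index: Option<u32>,
    //         dynamic_offsets: &[wgt::DynamicOffset],
    //     ) -> wgt::BufferAddress {
    //         base + dynamic_index.map_or(0, |i| dynamic_offsets[i as usize] as wgt::BufferAddress)
    //     }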
    unsafe fn set_push_constants(
        &mut self,
        layout: &super::PipelineLayout,
        stages: wgt::ShaderStages,
        offset: u32,
        data: &[u32],
    ) {
        let state_pc = &mut self.state.push_constants;
        if state_pc.len() < layout.total_push_constants as usize {
            state_pc.resize(layout.total_push_constants as usize, 0);
        }
        assert_eq!(offset as usize % WORD_SIZE, 0);

        let offset = offset as usize / WORD_SIZE;
        state_pc[offset..offset + data.len()].copy_from_slice(data);

        if stages.contains(wgt::ShaderStages::COMPUTE) {
            self.state.compute.as_ref().unwrap().set_bytes(
                layout.push_constants_infos.cs.unwrap().buffer_index as _,
                (layout.total_push_constants as usize * WORD_SIZE) as _,
                state_pc.as_ptr() as _,
            )
        }
        if stages.contains(wgt::ShaderStages::VERTEX) {
            self.state.render.as_ref().unwrap().set_vertex_bytes(
                layout.push_constants_infos.vs.unwrap().buffer_index as _,
                (layout.total_push_constants as usize * WORD_SIZE) as _,
                state_pc.as_ptr() as _,
            )
        }
        if stages.contains(wgt::ShaderStages::FRAGMENT) {
            self.state.render.as_ref().unwrap().set_fragment_bytes(
                layout.push_constants_infos.fs.unwrap().buffer_index as _,
                (layout.total_push_constants as usize * WORD_SIZE) as _,
                state_pc.as_ptr() as _,
            )
        }
    }

    unsafe fn insert_debug_marker(&mut self, label: &str) {
        if let Some(encoder) = self.enter_any() {
            encoder.insert_debug_signpost(label);
        }
    }

    unsafe fn begin_debug_marker(&mut self, group_label: &str) {
        if let Some(encoder) = self.enter_any() {
            encoder.push_debug_group(group_label);
        } else if let Some(ref buf) = self.raw_cmd_buf {
            buf.push_debug_group(group_label);
        }
    }

    unsafe fn end_debug_marker(&mut self) {
        if let Some(encoder) = self.enter_any() {
            encoder.pop_debug_group();
        } else if let Some(ref buf) = self.raw_cmd_buf {
            buf.pop_debug_group();
        }
    }

    unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) {
        self.state.raw_primitive_type = pipeline.raw_primitive_type;
        self.state.stage_infos.vs.assign_from(&pipeline.vs_info);
        self.state.stage_infos.fs.assign_from(&pipeline.fs_info);

        let encoder = self.state.render.as_ref().unwrap();
        encoder.set_render_pipeline_state(&pipeline.raw);
        encoder.set_front_facing_winding(pipeline.raw_front_winding);
        encoder.set_cull_mode(pipeline.raw_cull_mode);
        encoder.set_triangle_fill_mode(pipeline.raw_triangle_fill_mode);
        if let Some(depth_clip) = pipeline.raw_depth_clip_mode {
            encoder.set_depth_clip_mode(depth_clip);
        }
        if let Some((ref state, bias)) = pipeline.depth_stencil {
            encoder.set_depth_stencil_state(state);
            encoder.set_depth_bias(bias.constant as f32, bias.slope_scale, bias.clamp);
        }

        {
            if let Some((index, sizes)) = self
                .state
                .make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes)
            {
                encoder.set_vertex_bytes(
                    index as _,
                    (sizes.len() * WORD_SIZE) as u64,
                    sizes.as_ptr() as _,
                );
            }
        }
        if pipeline.fs_lib.is_some() {
            if let Some((index, sizes)) = self
                .state
                .make_sizes_buffer_update(naga::ShaderStage::Fragment, &mut self.temp.binding_sizes)
            {
                encoder.set_fragment_bytes(
                    index as _,
                    (sizes.len() * WORD_SIZE) as u64,
                    sizes.as_ptr() as _,
                );
            }
        }
    }

    unsafe fn set_index_buffer<'a>(
        &mut self,
        binding: crate::BufferBinding<'a, super::Api>,
        format: wgt::IndexFormat,
    ) {
        let (stride, raw_type) = match format {
            wgt::IndexFormat::Uint16 => (2, mtl::MTLIndexType::UInt16),
            wgt::IndexFormat::Uint32 => (4, mtl::MTLIndexType::UInt32),
        };
        self.state.index = Some(super::IndexState {
            buffer_ptr: AsNative::from(binding.buffer.raw.as_ref()),
            offset: binding.offset,
            stride,
            raw_type,
        });
    }
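    // Note: `IndexState` is captured here because Metal's indexed draw
    // calls take a byte offset rather than a first-index; `draw_indexed`
    // below computes it as `offset + stride * start_index`. E.g. a
    // Uint16 buffer bound at byte offset 256 with `start_index = 10`
    // resolves to 256 + 2 * 10 = 276.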
    unsafe fn set_vertex_buffer<'a>(
        &mut self,
        index: u32,
        binding: crate::BufferBinding<'a, super::Api>,
    ) {
        let buffer_index = self.shared.private_caps.max_vertex_buffers as u64 - 1 - index as u64;
        let encoder = self.state.render.as_ref().unwrap();
        encoder.set_vertex_buffer(buffer_index, Some(&binding.buffer.raw), binding.offset);
    }

    unsafe fn set_viewport(&mut self, rect: &crate::Rect<f32>, depth_range: Range<f32>) {
        let zfar = if self.shared.disabilities.broken_viewport_near_depth {
            depth_range.end - depth_range.start
        } else {
            depth_range.end
        };
        let encoder = self.state.render.as_ref().unwrap();
        encoder.set_viewport(mtl::MTLViewport {
            originX: rect.x as _,
            originY: rect.y as _,
            width: rect.w as _,
            height: rect.h as _,
            znear: depth_range.start as _,
            zfar: zfar as _,
        });
    }

    unsafe fn set_scissor_rect(&mut self, rect: &crate::Rect<u32>) {
        //TODO: support empty scissors by modifying the viewport
        let scissor = mtl::MTLScissorRect {
            x: rect.x as _,
            y: rect.y as _,
            width: rect.w as _,
            height: rect.h as _,
        };
        let encoder = self.state.render.as_ref().unwrap();
        encoder.set_scissor_rect(scissor);
    }

    unsafe fn set_stencil_reference(&mut self, value: u32) {
        let encoder = self.state.render.as_ref().unwrap();
        encoder.set_stencil_front_back_reference_value(value, value);
    }

    unsafe fn set_blend_constants(&mut self, color: &[f32; 4]) {
        let encoder = self.state.render.as_ref().unwrap();
        encoder.set_blend_color(color[0], color[1], color[2], color[3]);
    }

    unsafe fn draw(
        &mut self,
        start_vertex: u32,
        vertex_count: u32,
        start_instance: u32,
        instance_count: u32,
    ) {
        let encoder = self.state.render.as_ref().unwrap();
        if start_instance != 0 {
            encoder.draw_primitives_instanced_base_instance(
                self.state.raw_primitive_type,
                start_vertex as _,
                vertex_count as _,
                instance_count as _,
                start_instance as _,
            );
        } else if instance_count != 1 {
            encoder.draw_primitives_instanced(
                self.state.raw_primitive_type,
                start_vertex as _,
                vertex_count as _,
                instance_count as _,
            );
        } else {
            encoder.draw_primitives(
                self.state.raw_primitive_type,
                start_vertex as _,
                vertex_count as _,
            );
        }
    }

    unsafe fn draw_indexed(
        &mut self,
        start_index: u32,
        index_count: u32,
        base_vertex: i32,
        start_instance: u32,
        instance_count: u32,
    ) {
        let encoder = self.state.render.as_ref().unwrap();
        let index = self.state.index.as_ref().unwrap();
        let offset = index.offset + index.stride * start_index as wgt::BufferAddress;
        if base_vertex != 0 || start_instance != 0 {
            encoder.draw_indexed_primitives_instanced_base_instance(
                self.state.raw_primitive_type,
                index_count as _,
                index.raw_type,
                index.buffer_ptr.as_native(),
                offset,
                instance_count as _,
                base_vertex as _,
                start_instance as _,
            );
        } else if instance_count != 1 {
            encoder.draw_indexed_primitives_instanced(
                self.state.raw_primitive_type,
                index_count as _,
                index.raw_type,
                index.buffer_ptr.as_native(),
                offset,
                instance_count as _,
            );
        } else {
            encoder.draw_indexed_primitives(
                self.state.raw_primitive_type,
                index_count as _,
                index.raw_type,
                index.buffer_ptr.as_native(),
                offset,
            );
        }
    }
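    // Note: Metal has no multi-draw indirect entry point, so the
    // indirect draws below emulate `draw_count > 1` by looping and
    // stepping the argument offset by one argument struct per draw:
    //
    //     offset += mem::size_of::<wgt::DrawIndirectArgs>() as wgt::BufferAddress;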
    unsafe fn draw_indirect(
        &mut self,
        buffer: &super::Buffer,
        mut offset: wgt::BufferAddress,
        draw_count: u32,
    ) {
        let encoder = self.state.render.as_ref().unwrap();
        for _ in 0..draw_count {
            encoder.draw_primitives_indirect(self.state.raw_primitive_type, &buffer.raw, offset);
            offset += mem::size_of::<wgt::DrawIndirectArgs>() as wgt::BufferAddress;
        }
    }

    unsafe fn draw_indexed_indirect(
        &mut self,
        buffer: &super::Buffer,
        mut offset: wgt::BufferAddress,
        draw_count: u32,
    ) {
        let encoder = self.state.render.as_ref().unwrap();
        let index = self.state.index.as_ref().unwrap();
        for _ in 0..draw_count {
            encoder.draw_indexed_primitives_indirect(
                self.state.raw_primitive_type,
                index.raw_type,
                index.buffer_ptr.as_native(),
                index.offset,
                &buffer.raw,
                offset,
            );
            offset += mem::size_of::<wgt::DrawIndexedIndirectArgs>() as wgt::BufferAddress;
        }
    }

    unsafe fn draw_indirect_count(
        &mut self,
        _buffer: &super::Buffer,
        _offset: wgt::BufferAddress,
        _count_buffer: &super::Buffer,
        _count_offset: wgt::BufferAddress,
        _max_count: u32,
    ) {
        //TODO
    }

    unsafe fn draw_indexed_indirect_count(
        &mut self,
        _buffer: &super::Buffer,
        _offset: wgt::BufferAddress,
        _count_buffer: &super::Buffer,
        _count_offset: wgt::BufferAddress,
        _max_count: u32,
    ) {
        //TODO
    }

    // compute

    unsafe fn begin_compute_pass(&mut self, desc: &crate::ComputePassDescriptor) {
        self.begin_pass();

        let raw = self.raw_cmd_buf.as_ref().unwrap();
        objc::rc::autoreleasepool(|| {
            let encoder = raw.new_compute_command_encoder();
            if let Some(label) = desc.label {
                encoder.set_label(label);
            }
            self.state.compute = Some(encoder.to_owned());
        });
    }

    unsafe fn end_compute_pass(&mut self) {
        self.state.compute.take().unwrap().end_encoding();
    }

    unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) {
        self.state.raw_wg_size = pipeline.work_group_size;
        self.state.stage_infos.cs.assign_from(&pipeline.cs_info);

        let encoder = self.state.compute.as_ref().unwrap();
        encoder.set_compute_pipeline_state(&pipeline.raw);

        if let Some((index, sizes)) = self
            .state
            .make_sizes_buffer_update(naga::ShaderStage::Compute, &mut self.temp.binding_sizes)
        {
            encoder.set_bytes(
                index as _,
                (sizes.len() * WORD_SIZE) as u64,
                sizes.as_ptr() as _,
            );
        }

        // update the threadgroup memory sizes
        while self.state.work_group_memory_sizes.len() < pipeline.work_group_memory_sizes.len() {
            self.state.work_group_memory_sizes.push(0);
        }
        for (index, (cur_size, pipeline_size)) in self
            .state
            .work_group_memory_sizes
            .iter_mut()
            .zip(pipeline.work_group_memory_sizes.iter())
            .enumerate()
        {
            const ALIGN_MASK: u32 = 0xF; // must be a multiple of 16 bytes
            let size = ((*pipeline_size - 1) | ALIGN_MASK) + 1;
            if *cur_size != size {
                *cur_size = size;
                encoder.set_threadgroup_memory_length(index as _, size as _);
            }
        }
    }

    unsafe fn dispatch(&mut self, count: [u32; 3]) {
        let encoder = self.state.compute.as_ref().unwrap();
        let raw_count = mtl::MTLSize {
            width: count[0] as u64,
            height: count[1] as u64,
            depth: count[2] as u64,
        };
        encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size);
    }

    unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
        let encoder = self.state.compute.as_ref().unwrap();
        encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.state.raw_wg_size);
    }
}
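// Note: the threadgroup memory lengths set in `set_compute_pipeline`
// are rounded up to a 16-byte multiple via `((size - 1) | 0xF) + 1`,
// e.g. 1 -> 16, 16 -> 16, 20 -> 32; equivalent to this hypothetical
// helper:
//
//     fn align_to_16(size: u32) -> u32 {
//         ((size - 1) | 0xF) + 1
//     }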