diff options
Diffstat (limited to 'third_party/rust/wgpu-core/src/device/resource.rs')
-rw-r--r-- | third_party/rust/wgpu-core/src/device/resource.rs | 544 |
1 files changed, 306 insertions, 238 deletions
diff --git a/third_party/rust/wgpu-core/src/device/resource.rs b/third_party/rust/wgpu-core/src/device/resource.rs index 4892aecb75..2541af7c70 100644 --- a/third_party/rust/wgpu-core/src/device/resource.rs +++ b/third_party/rust/wgpu-core/src/device/resource.rs @@ -7,18 +7,20 @@ use crate::{ bgl, life::{LifetimeTracker, WaitIdleError}, queue::PendingWrites, - AttachmentData, CommandAllocator, DeviceLostInvocation, MissingDownlevelFlags, - MissingFeatures, RenderPassContext, CLEANUP_WAIT_MS, + AttachmentData, DeviceLostInvocation, MissingDownlevelFlags, MissingFeatures, + RenderPassContext, CLEANUP_WAIT_MS, }, hal_api::HalApi, hal_label, hub::Hub, + id, init_tracker::{ BufferInitTracker, BufferInitTrackerAction, MemoryInitKind, TextureInitRange, TextureInitTracker, TextureInitTrackerAction, }, instance::Adapter, - pipeline, + lock::{rank, Mutex, MutexGuard, RwLock}, + pipeline::{self}, pool::ResourcePool, registry::Registry, resource::{ @@ -41,7 +43,6 @@ use crate::{ use arrayvec::ArrayVec; use hal::{CommandEncoder as _, Device as _}; use once_cell::sync::OnceCell; -use parking_lot::{Mutex, MutexGuard, RwLock}; use smallvec::SmallVec; use thiserror::Error; @@ -97,7 +98,7 @@ pub struct Device<A: HalApi> { pub(crate) zero_buffer: Option<A::Buffer>, pub(crate) info: ResourceInfo<Device<A>>, - pub(crate) command_allocator: Mutex<Option<CommandAllocator<A>>>, + pub(crate) command_allocator: command::CommandAllocator<A>, //Note: The submission index here corresponds to the last submission that is done. pub(crate) active_submission_index: AtomicU64, //SubmissionIndex, // NOTE: if both are needed, the `snatchable_lock` must be consistently acquired before the @@ -126,9 +127,6 @@ pub struct Device<A: HalApi> { pub(crate) tracker_indices: TrackerIndexAllocators, // Life tracker should be locked right after the device and before anything else. life_tracker: Mutex<LifetimeTracker<A>>, - /// Temporary storage for resource management functions. Cleared at the end - /// of every call (unless an error occurs). - pub(crate) temp_suspected: Mutex<Option<ResourceMaps<A>>>, /// Pool of bind group layouts, allowing deduplication. pub(crate) bgl_pool: ResourcePool<bgl::EntryMap, BindGroupLayout<A>>, pub(crate) alignments: hal::Alignments, @@ -141,6 +139,10 @@ pub struct Device<A: HalApi> { #[cfg(feature = "trace")] pub(crate) trace: Mutex<Option<trace::Trace>>, pub(crate) usage_scopes: UsageScopePool<A>, + + /// Temporary storage, cleared at the start of every call, + /// retained only to save allocations. + temp_suspected: Mutex<Option<ResourceMaps<A>>>, } pub(crate) enum DeferredDestroy<A: HalApi> { @@ -165,7 +167,7 @@ impl<A: HalApi> Drop for Device<A> { let raw = self.raw.take().unwrap(); let pending_writes = self.pending_writes.lock().take().unwrap(); pending_writes.dispose(&raw); - self.command_allocator.lock().take().unwrap().dispose(&raw); + self.command_allocator.dispose(&raw); unsafe { raw.destroy_buffer(self.zero_buffer.take().unwrap()); raw.destroy_fence(self.fence.write().take().unwrap()); @@ -223,10 +225,8 @@ impl<A: HalApi> Device<A> { let fence = unsafe { raw_device.create_fence() }.map_err(|_| CreateDeviceError::OutOfMemory)?; - let mut com_alloc = CommandAllocator { - free_encoders: Vec::new(), - }; - let pending_encoder = com_alloc + let command_allocator = command::CommandAllocator::new(); + let pending_encoder = command_allocator .acquire_encoder(&raw_device, raw_queue) .map_err(|_| CreateDeviceError::OutOfMemory)?; let mut pending_writes = queue::PendingWrites::<A>::new(pending_encoder); @@ -271,38 +271,44 @@ impl<A: HalApi> Device<A> { queue_to_drop: OnceCell::new(), zero_buffer: Some(zero_buffer), info: ResourceInfo::new("<device>", None), - command_allocator: Mutex::new(Some(com_alloc)), + command_allocator, active_submission_index: AtomicU64::new(0), - fence: RwLock::new(Some(fence)), - snatchable_lock: unsafe { SnatchLock::new() }, + fence: RwLock::new(rank::DEVICE_FENCE, Some(fence)), + snatchable_lock: unsafe { SnatchLock::new(rank::DEVICE_SNATCHABLE_LOCK) }, valid: AtomicBool::new(true), - trackers: Mutex::new(Tracker::new()), + trackers: Mutex::new(rank::DEVICE_TRACKERS, Tracker::new()), tracker_indices: TrackerIndexAllocators::new(), - life_tracker: Mutex::new(life::LifetimeTracker::new()), - temp_suspected: Mutex::new(Some(life::ResourceMaps::new())), + life_tracker: Mutex::new(rank::DEVICE_LIFE_TRACKER, life::LifetimeTracker::new()), + temp_suspected: Mutex::new( + rank::DEVICE_TEMP_SUSPECTED, + Some(life::ResourceMaps::new()), + ), bgl_pool: ResourcePool::new(), #[cfg(feature = "trace")] - trace: Mutex::new(trace_path.and_then(|path| match trace::Trace::new(path) { - Ok(mut trace) => { - trace.add(trace::Action::Init { - desc: desc.clone(), - backend: A::VARIANT, - }); - Some(trace) - } - Err(e) => { - log::error!("Unable to start a trace in '{path:?}': {e}"); - None - } - })), + trace: Mutex::new( + rank::DEVICE_TRACE, + trace_path.and_then(|path| match trace::Trace::new(path) { + Ok(mut trace) => { + trace.add(trace::Action::Init { + desc: desc.clone(), + backend: A::VARIANT, + }); + Some(trace) + } + Err(e) => { + log::error!("Unable to start a trace in '{path:?}': {e}"); + None + } + }), + ), alignments, limits: desc.required_limits.clone(), features: desc.required_features, downlevel, instance_flags, - pending_writes: Mutex::new(Some(pending_writes)), - deferred_destroy: Mutex::new(Vec::new()), - usage_scopes: Default::default(), + pending_writes: Mutex::new(rank::DEVICE_PENDING_WRITES, Some(pending_writes)), + deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()), + usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()), }) } @@ -379,7 +385,7 @@ impl<A: HalApi> Device<A> { /// Check this device for completed commands. /// - /// The `maintain` argument tells how the maintence function should behave, either + /// The `maintain` argument tells how the maintenance function should behave, either /// blocking or just polling the current state of the gpu. /// /// Return a pair `(closures, queue_empty)`, where: @@ -392,11 +398,12 @@ impl<A: HalApi> Device<A> { /// return it to our callers.) pub(crate) fn maintain<'this>( &'this self, - fence: &A::Fence, + fence_guard: crate::lock::RwLockReadGuard<Option<A::Fence>>, maintain: wgt::Maintain<queue::WrappedSubmissionIndex>, snatch_guard: SnatchGuard, ) -> Result<(UserClosures, bool), WaitIdleError> { profiling::scope!("Device::maintain"); + let fence = fence_guard.as_ref().unwrap(); let last_done_index = if maintain.is_wait() { let index_to_wait_for = match maintain { wgt::Maintain::WaitForSubmissionIndex(submission_index) => { @@ -425,28 +432,12 @@ impl<A: HalApi> Device<A> { }; let mut life_tracker = self.lock_life(); - let submission_closures = life_tracker.triage_submissions( - last_done_index, - self.command_allocator.lock().as_mut().unwrap(), - ); - - { - // Normally, `temp_suspected` exists only to save heap - // allocations: it's cleared at the start of the function - // call, and cleared by the end. But `Global::queue_submit` is - // fallible; if it exits early, it may leave some resources in - // `temp_suspected`. - let temp_suspected = self - .temp_suspected - .lock() - .replace(ResourceMaps::new()) - .unwrap(); + let submission_closures = + life_tracker.triage_submissions(last_done_index, &self.command_allocator); - life_tracker.suspected_resources.extend(temp_suspected); + life_tracker.triage_suspected(&self.trackers); - life_tracker.triage_suspected(&self.trackers); - life_tracker.triage_mapped(); - } + life_tracker.triage_mapped(); let mapping_closures = life_tracker.handle_mapping(self.raw(), &self.trackers, &snatch_guard); @@ -478,6 +469,7 @@ impl<A: HalApi> Device<A> { // Don't hold the locks while calling release_gpu_resources. drop(life_tracker); + drop(fence_guard); drop(snatch_guard); if should_release_gpu_resource { @@ -493,12 +485,14 @@ impl<A: HalApi> Device<A> { } pub(crate) fn untrack(&self, trackers: &Tracker<A>) { + // If we have a previously allocated `ResourceMap`, just use that. let mut temp_suspected = self .temp_suspected .lock() - .replace(ResourceMaps::new()) - .unwrap(); + .take() + .unwrap_or_else(|| ResourceMaps::new()); temp_suspected.clear(); + // As the tracker is cleared/dropped, we need to consider all the resources // that it references for destruction in the next GC pass. { @@ -559,7 +553,11 @@ impl<A: HalApi> Device<A> { } } } - self.lock_life().suspected_resources.extend(temp_suspected); + self.lock_life() + .suspected_resources + .extend(&mut temp_suspected); + // Save this resource map for later reuse. + *self.temp_suspected.lock() = Some(temp_suspected); } pub(crate) fn create_buffer( @@ -653,14 +651,17 @@ impl<A: HalApi> Device<A> { device: self.clone(), usage: desc.usage, size: desc.size, - initialization_status: RwLock::new(BufferInitTracker::new(aligned_size)), - sync_mapped_writes: Mutex::new(None), - map_state: Mutex::new(resource::BufferMapState::Idle), + initialization_status: RwLock::new( + rank::BUFFER_INITIALIZATION_STATUS, + BufferInitTracker::new(aligned_size), + ), + sync_mapped_writes: Mutex::new(rank::BUFFER_SYNC_MAPPED_WRITES, None), + map_state: Mutex::new(rank::BUFFER_MAP_STATE, resource::BufferMapState::Idle), info: ResourceInfo::new( desc.label.borrow_or_default(), Some(self.tracker_indices.buffers.clone()), ), - bind_groups: Mutex::new(Vec::new()), + bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), }) } @@ -680,10 +681,10 @@ impl<A: HalApi> Device<A> { desc: desc.map_label(|_| ()), hal_usage, format_features, - initialization_status: RwLock::new(TextureInitTracker::new( - desc.mip_level_count, - desc.array_layer_count(), - )), + initialization_status: RwLock::new( + rank::TEXTURE_INITIALIZATION_STATUS, + TextureInitTracker::new(desc.mip_level_count, desc.array_layer_count()), + ), full_range: TextureSelector { mips: 0..desc.mip_level_count, layers: 0..desc.array_layer_count(), @@ -692,9 +693,9 @@ impl<A: HalApi> Device<A> { desc.label.borrow_or_default(), Some(self.tracker_indices.textures.clone()), ), - clear_mode: RwLock::new(clear_mode), - views: Mutex::new(Vec::new()), - bind_groups: Mutex::new(Vec::new()), + clear_mode: RwLock::new(rank::TEXTURE_CLEAR_MODE, clear_mode), + views: Mutex::new(rank::TEXTURE_VIEWS, Vec::new()), + bind_groups: Mutex::new(rank::TEXTURE_BIND_GROUPS, Vec::new()), } } @@ -710,14 +711,17 @@ impl<A: HalApi> Device<A> { device: self.clone(), usage: desc.usage, size: desc.size, - initialization_status: RwLock::new(BufferInitTracker::new(0)), - sync_mapped_writes: Mutex::new(None), - map_state: Mutex::new(resource::BufferMapState::Idle), + initialization_status: RwLock::new( + rank::BUFFER_INITIALIZATION_STATUS, + BufferInitTracker::new(0), + ), + sync_mapped_writes: Mutex::new(rank::BUFFER_SYNC_MAPPED_WRITES, None), + map_state: Mutex::new(rank::BUFFER_MAP_STATE, resource::BufferMapState::Idle), info: ResourceInfo::new( desc.label.borrow_or_default(), Some(self.tracker_indices.buffers.clone()), ), - bind_groups: Mutex::new(Vec::new()), + bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), } } @@ -1421,7 +1425,7 @@ impl<A: HalApi> Device<A> { pipeline::ShaderModuleSource::Wgsl(code) => { profiling::scope!("naga::front::wgsl::parse_str"); let module = naga::front::wgsl::parse_str(&code).map_err(|inner| { - pipeline::CreateShaderModuleError::Parsing(pipeline::ShaderError { + pipeline::CreateShaderModuleError::Parsing(naga::error::ShaderError { source: code.to_string(), label: desc.label.as_ref().map(|l| l.to_string()), inner: Box::new(inner), @@ -1434,7 +1438,7 @@ impl<A: HalApi> Device<A> { let parser = naga::front::spv::Frontend::new(spv.iter().cloned(), &options); profiling::scope!("naga::front::spv::Frontend"); let module = parser.parse().map_err(|inner| { - pipeline::CreateShaderModuleError::ParsingSpirV(pipeline::ShaderError { + pipeline::CreateShaderModuleError::ParsingSpirV(naga::error::ShaderError { source: String::new(), label: desc.label.as_ref().map(|l| l.to_string()), inner: Box::new(inner), @@ -1447,7 +1451,7 @@ impl<A: HalApi> Device<A> { let mut parser = naga::front::glsl::Frontend::default(); profiling::scope!("naga::front::glsl::Frontend.parse"); let module = parser.parse(&options, &code).map_err(|inner| { - pipeline::CreateShaderModuleError::ParsingGlsl(pipeline::ShaderError { + pipeline::CreateShaderModuleError::ParsingGlsl(naga::error::ShaderError { source: code.to_string(), label: desc.label.as_ref().map(|l| l.to_string()), inner: Box::new(inner), @@ -1471,9 +1475,78 @@ impl<A: HalApi> Device<A> { }; } - use naga::valid::Capabilities as Caps; profiling::scope!("naga::validate"); + let debug_source = + if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() { + Some(hal::DebugSource { + file_name: Cow::Owned( + desc.label + .as_ref() + .map_or("shader".to_string(), |l| l.to_string()), + ), + source_code: Cow::Owned(source.clone()), + }) + } else { + None + }; + + let info = self + .create_validator(naga::valid::ValidationFlags::all()) + .validate(&module) + .map_err(|inner| { + pipeline::CreateShaderModuleError::Validation(naga::error::ShaderError { + source, + label: desc.label.as_ref().map(|l| l.to_string()), + inner: Box::new(inner), + }) + })?; + + let interface = + validation::Interface::new(&module, &info, self.limits.clone(), self.features); + let hal_shader = hal::ShaderInput::Naga(hal::NagaShader { + module, + info, + debug_source, + }); + let hal_desc = hal::ShaderModuleDescriptor { + label: desc.label.to_hal(self.instance_flags), + runtime_checks: desc.shader_bound_checks.runtime_checks(), + }; + let raw = match unsafe { + self.raw + .as_ref() + .unwrap() + .create_shader_module(&hal_desc, hal_shader) + } { + Ok(raw) => raw, + Err(error) => { + return Err(match error { + hal::ShaderError::Device(error) => { + pipeline::CreateShaderModuleError::Device(error.into()) + } + hal::ShaderError::Compilation(ref msg) => { + log::error!("Shader error: {}", msg); + pipeline::CreateShaderModuleError::Generation + } + }) + } + }; + + Ok(pipeline::ShaderModule { + raw: Some(raw), + device: self.clone(), + interface: Some(interface), + info: ResourceInfo::new(desc.label.borrow_or_default(), None), + label: desc.label.borrow_or_default().to_string(), + }) + } + /// Create a validator with the given validation flags. + pub fn create_validator( + self: &Arc<Self>, + flags: naga::valid::ValidationFlags, + ) -> naga::valid::Validator { + use naga::valid::Capabilities as Caps; let mut caps = Caps::empty(); caps.set( Caps::PUSH_CONSTANT, @@ -1541,69 +1614,36 @@ impl<A: HalApi> Device<A> { .flags .contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES), ); + caps.set( + Caps::SUBGROUP, + self.features + .intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX), + ); + caps.set( + Caps::SUBGROUP_BARRIER, + self.features.intersects(wgt::Features::SUBGROUP_BARRIER), + ); - let debug_source = - if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() { - Some(hal::DebugSource { - file_name: Cow::Owned( - desc.label - .as_ref() - .map_or("shader".to_string(), |l| l.to_string()), - ), - source_code: Cow::Owned(source.clone()), - }) - } else { - None - }; - - let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps) - .validate(&module) - .map_err(|inner| { - pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError { - source, - label: desc.label.as_ref().map(|l| l.to_string()), - inner: Box::new(inner), - }) - })?; + let mut subgroup_stages = naga::valid::ShaderStages::empty(); + subgroup_stages.set( + naga::valid::ShaderStages::COMPUTE | naga::valid::ShaderStages::FRAGMENT, + self.features.contains(wgt::Features::SUBGROUP), + ); + subgroup_stages.set( + naga::valid::ShaderStages::VERTEX, + self.features.contains(wgt::Features::SUBGROUP_VERTEX), + ); - let interface = - validation::Interface::new(&module, &info, self.limits.clone(), self.features); - let hal_shader = hal::ShaderInput::Naga(hal::NagaShader { - module, - info, - debug_source, - }); - let hal_desc = hal::ShaderModuleDescriptor { - label: desc.label.to_hal(self.instance_flags), - runtime_checks: desc.shader_bound_checks.runtime_checks(), - }; - let raw = match unsafe { - self.raw - .as_ref() - .unwrap() - .create_shader_module(&hal_desc, hal_shader) - } { - Ok(raw) => raw, - Err(error) => { - return Err(match error { - hal::ShaderError::Device(error) => { - pipeline::CreateShaderModuleError::Device(error.into()) - } - hal::ShaderError::Compilation(ref msg) => { - log::error!("Shader error: {}", msg); - pipeline::CreateShaderModuleError::Generation - } - }) - } + let subgroup_operations = if caps.contains(Caps::SUBGROUP) { + use naga::valid::SubgroupOperationSet as S; + S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE + } else { + naga::valid::SubgroupOperationSet::empty() }; - - Ok(pipeline::ShaderModule { - raw: Some(raw), - device: self.clone(), - interface: Some(interface), - info: ResourceInfo::new(desc.label.borrow_or_default(), None), - label: desc.label.borrow_or_default().to_string(), - }) + let mut validator = naga::valid::Validator::new(flags, caps); + validator.subgroup_stages(subgroup_stages); + validator.subgroup_operations(subgroup_operations); + validator } #[allow(unused_unsafe)] @@ -1913,6 +1953,7 @@ impl<A: HalApi> Device<A> { used: &mut BindGroupStates<A>, storage: &'a Storage<Buffer<A>>, limits: &wgt::Limits, + device_id: id::Id<id::markers::Device>, snatch_guard: &'a SnatchGuard<'a>, ) -> Result<hal::BufferBinding<'a, A>, binding_model::CreateBindGroupError> { use crate::binding_model::CreateBindGroupError as Error; @@ -1931,6 +1972,7 @@ impl<A: HalApi> Device<A> { }) } }; + let (pub_usage, internal_use, range_limit) = match binding_ty { wgt::BufferBindingType::Uniform => ( wgt::BufferUsages::UNIFORM, @@ -1963,6 +2005,10 @@ impl<A: HalApi> Device<A> { .add_single(storage, bb.buffer_id, internal_use) .ok_or(Error::InvalidBuffer(bb.buffer_id))?; + if buffer.device.as_info().id() != device_id { + return Err(DeviceError::WrongDevice.into()); + } + check_buffer_usage(bb.buffer_id, buffer.usage, pub_usage)?; let raw_buffer = buffer .raw @@ -2041,13 +2087,53 @@ impl<A: HalApi> Device<A> { }) } - pub(crate) fn create_texture_binding( - view: &TextureView<A>, - internal_use: hal::TextureUses, - pub_usage: wgt::TextureUsages, + fn create_sampler_binding<'a>( + used: &BindGroupStates<A>, + storage: &'a Storage<Sampler<A>>, + id: id::Id<id::markers::Sampler>, + device_id: id::Id<id::markers::Device>, + ) -> Result<&'a Sampler<A>, binding_model::CreateBindGroupError> { + use crate::binding_model::CreateBindGroupError as Error; + + let sampler = used + .samplers + .add_single(storage, id) + .ok_or(Error::InvalidSampler(id))?; + + if sampler.device.as_info().id() != device_id { + return Err(DeviceError::WrongDevice.into()); + } + + Ok(sampler) + } + + pub(crate) fn create_texture_binding<'a>( + self: &Arc<Self>, + binding: u32, + decl: &wgt::BindGroupLayoutEntry, + storage: &'a Storage<TextureView<A>>, + id: id::Id<id::markers::TextureView>, used: &mut BindGroupStates<A>, used_texture_ranges: &mut Vec<TextureInitTrackerAction<A>>, - ) -> Result<(), binding_model::CreateBindGroupError> { + snatch_guard: &'a SnatchGuard<'a>, + ) -> Result<hal::TextureBinding<'a, A>, binding_model::CreateBindGroupError> { + use crate::binding_model::CreateBindGroupError as Error; + + let view = used + .views + .add_single(storage, id) + .ok_or(Error::InvalidTextureView(id))?; + + if view.device.as_info().id() != self.as_info().id() { + return Err(DeviceError::WrongDevice.into()); + } + + let (pub_usage, internal_use) = self.texture_use_parameters( + binding, + decl, + view, + "SampledTexture, ReadonlyStorageTexture or WriteonlyStorageTexture", + )?; let texture = &view.parent; let texture_id = texture.as_info().id(); // Careful here: the texture may no longer have its own ref count, @@ -2077,7 +2163,12 @@ impl<A: HalApi> Device<A> { kind: MemoryInitKind::NeedsInitializedMemory, }); - Ok(()) + Ok(hal::TextureBinding { + view: view + .raw(snatch_guard) + .ok_or(Error::InvalidTextureView(id))?, + usage: internal_use, + }) } // This function expects the provided bind group layout to be resolved @@ -2139,6 +2230,7 @@ impl<A: HalApi> Device<A> { &mut used, &*buffer_guard, &self.limits, + self.as_info().id(), &snatch_guard, )?; @@ -2162,105 +2254,86 @@ impl<A: HalApi> Device<A> { &mut used, &*buffer_guard, &self.limits, + self.as_info().id(), &snatch_guard, )?; hal_buffers.push(bb); } (res_index, num_bindings) } - Br::Sampler(id) => { - match decl.ty { - wgt::BindingType::Sampler(ty) => { - let sampler = used - .samplers - .add_single(&*sampler_guard, id) - .ok_or(Error::InvalidSampler(id))?; - - if sampler.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } - - // Allowed sampler values for filtering and comparison - let (allowed_filtering, allowed_comparison) = match ty { - wgt::SamplerBindingType::Filtering => (None, false), - wgt::SamplerBindingType::NonFiltering => (Some(false), false), - wgt::SamplerBindingType::Comparison => (None, true), - }; - - if let Some(allowed_filtering) = allowed_filtering { - if allowed_filtering != sampler.filtering { - return Err(Error::WrongSamplerFiltering { - binding, - layout_flt: allowed_filtering, - sampler_flt: sampler.filtering, - }); - } - } + Br::Sampler(id) => match decl.ty { + wgt::BindingType::Sampler(ty) => { + let sampler = Self::create_sampler_binding( + &used, + &sampler_guard, + id, + self.as_info().id(), + )?; - if allowed_comparison != sampler.comparison { - return Err(Error::WrongSamplerComparison { + let (allowed_filtering, allowed_comparison) = match ty { + wgt::SamplerBindingType::Filtering => (None, false), + wgt::SamplerBindingType::NonFiltering => (Some(false), false), + wgt::SamplerBindingType::Comparison => (None, true), + }; + if let Some(allowed_filtering) = allowed_filtering { + if allowed_filtering != sampler.filtering { + return Err(Error::WrongSamplerFiltering { binding, - layout_cmp: allowed_comparison, - sampler_cmp: sampler.comparison, + layout_flt: allowed_filtering, + sampler_flt: sampler.filtering, }); } - - let res_index = hal_samplers.len(); - hal_samplers.push(sampler.raw()); - (res_index, 1) } - _ => { - return Err(Error::WrongBindingType { + if allowed_comparison != sampler.comparison { + return Err(Error::WrongSamplerComparison { binding, - actual: decl.ty, - expected: "Sampler", - }) + layout_cmp: allowed_comparison, + sampler_cmp: sampler.comparison, + }); } + + let res_index = hal_samplers.len(); + hal_samplers.push(sampler.raw()); + (res_index, 1) } - } + _ => { + return Err(Error::WrongBindingType { + binding, + actual: decl.ty, + expected: "Sampler", + }) + } + }, Br::SamplerArray(ref bindings_array) => { let num_bindings = bindings_array.len(); Self::check_array_binding(self.features, decl.count, num_bindings)?; let res_index = hal_samplers.len(); for &id in bindings_array.iter() { - let sampler = used - .samplers - .add_single(&*sampler_guard, id) - .ok_or(Error::InvalidSampler(id))?; - if sampler.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + let sampler = Self::create_sampler_binding( + &used, + &sampler_guard, + id, + self.as_info().id(), + )?; + hal_samplers.push(sampler.raw()); } (res_index, num_bindings) } Br::TextureView(id) => { - let view = used - .views - .add_single(&*texture_view_guard, id) - .ok_or(Error::InvalidTextureView(id))?; - let (pub_usage, internal_use) = self.texture_use_parameters( + let tb = self.create_texture_binding( binding, decl, - view, - "SampledTexture, ReadonlyStorageTexture or WriteonlyStorageTexture", - )?; - Self::create_texture_binding( - view, - internal_use, - pub_usage, + &texture_view_guard, + id, &mut used, &mut used_texture_ranges, + &snatch_guard, )?; let res_index = hal_textures.len(); - hal_textures.push(hal::TextureBinding { - view: view - .raw(&snatch_guard) - .ok_or(Error::InvalidTextureView(id))?, - usage: internal_use, - }); + hal_textures.push(tb); (res_index, 1) } Br::TextureViewArray(ref bindings_array) => { @@ -2269,26 +2342,17 @@ impl<A: HalApi> Device<A> { let res_index = hal_textures.len(); for &id in bindings_array.iter() { - let view = used - .views - .add_single(&*texture_view_guard, id) - .ok_or(Error::InvalidTextureView(id))?; - let (pub_usage, internal_use) = - self.texture_use_parameters(binding, decl, view, - "SampledTextureArray, ReadonlyStorageTextureArray or WriteonlyStorageTextureArray")?; - Self::create_texture_binding( - view, - internal_use, - pub_usage, + let tb = self.create_texture_binding( + binding, + decl, + &texture_view_guard, + id, &mut used, &mut used_texture_ranges, + &snatch_guard, )?; - hal_textures.push(hal::TextureBinding { - view: view - .raw(&snatch_guard) - .ok_or(Error::InvalidTextureView(id))?, - usage: internal_use, - }); + + hal_textures.push(tb); } (res_index, num_bindings) @@ -2762,8 +2826,10 @@ impl<A: HalApi> Device<A> { label: desc.label.to_hal(self.instance_flags), layout: pipeline_layout.raw(), stage: hal::ProgrammableStage { - entry_point: final_entry_point_name.as_ref(), module: shader_module.raw(), + entry_point: final_entry_point_name.as_ref(), + constants: desc.stage.constants.as_ref(), + zero_initialize_workgroup_memory: desc.stage.zero_initialize_workgroup_memory, }, }; @@ -3178,6 +3244,8 @@ impl<A: HalApi> Device<A> { hal::ProgrammableStage { module: vertex_shader_module.raw(), entry_point: &vertex_entry_point_name, + constants: stage_desc.constants.as_ref(), + zero_initialize_workgroup_memory: stage_desc.zero_initialize_workgroup_memory, } }; @@ -3237,6 +3305,10 @@ impl<A: HalApi> Device<A> { Some(hal::ProgrammableStage { module: shader_module.raw(), entry_point: &fragment_entry_point_name, + constants: fragment_state.stage.constants.as_ref(), + zero_initialize_workgroup_memory: fragment_state + .stage + .zero_initialize_workgroup_memory, }) } None => None, @@ -3482,10 +3554,9 @@ impl<A: HalApi> Device<A> { .map_err(DeviceError::from)? }; drop(guard); - let closures = self.lock_life().triage_submissions( - submission_index, - self.command_allocator.lock().as_mut().unwrap(), - ); + let closures = self + .lock_life() + .triage_submissions(submission_index, &self.command_allocator); assert!( closures.is_empty(), "wait_for_submit is not expected to work with closures" @@ -3613,10 +3684,7 @@ impl<A: HalApi> Device<A> { log::error!("failed to wait for the device: {error}"); } let mut life_tracker = self.lock_life(); - let _ = life_tracker.triage_submissions( - current_index, - self.command_allocator.lock().as_mut().unwrap(), - ); + let _ = life_tracker.triage_submissions(current_index, &self.command_allocator); if let Some(device_lost_closure) = life_tracker.device_lost_closure.take() { // It's important to not hold the lock while calling the closure. drop(life_tracker); |