summaryrefslogtreecommitdiffstats
path: root/third_party/rust/audio_thread_priority/src/rt_win.rs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--third_party/rust/audio_thread_priority/src/rt_win.rs233
1 files changed, 195 insertions, 38 deletions
diff --git a/third_party/rust/audio_thread_priority/src/rt_win.rs b/third_party/rust/audio_thread_priority/src/rt_win.rs
index 3a2bbf29d9..480c5bd75c 100644
--- a/third_party/rust/audio_thread_priority/src/rt_win.rs
+++ b/third_party/rust/audio_thread_priority/src/rt_win.rs
@@ -1,14 +1,11 @@
-use windows_sys::s;
-use windows_sys::Win32::Foundation::GetLastError;
-use windows_sys::Win32::Foundation::FALSE;
-use windows_sys::Win32::Foundation::HANDLE;
-use windows_sys::Win32::System::Threading::{
- AvRevertMmThreadCharacteristics, AvSetMmThreadCharacteristicsA,
-};
-
+use self::avrt_lib::AvRtLibrary;
use crate::AudioThreadPriorityError;
-
use log::info;
+use std::sync::OnceLock;
+use windows_sys::{
+ w,
+ Win32::Foundation::{HANDLE, WIN32_ERROR},
+};
#[derive(Debug)]
pub struct RtPriorityHandleInternal {
@@ -17,7 +14,7 @@ pub struct RtPriorityHandleInternal {
}
impl RtPriorityHandleInternal {
- pub fn new(mmcss_task_index: u32, task_handle: HANDLE) -> RtPriorityHandleInternal {
+ fn new(mmcss_task_index: u32, task_handle: HANDLE) -> RtPriorityHandleInternal {
RtPriorityHandleInternal {
mmcss_task_index,
task_handle,
@@ -25,45 +22,205 @@ impl RtPriorityHandleInternal {
}
}
+fn avrt() -> Result<&'static AvRtLibrary, AudioThreadPriorityError> {
+ static AV_RT_LIBRARY: OnceLock<Result<AvRtLibrary, WIN32_ERROR>> = OnceLock::new();
+ AV_RT_LIBRARY
+ .get_or_init(AvRtLibrary::try_new)
+ .as_ref()
+ .map_err(|win32_error| {
+ AudioThreadPriorityError::new(&format!("Unable to load avrt.dll ({win32_error})"))
+ })
+}
+
+pub fn promote_current_thread_to_real_time_internal(
+ _audio_buffer_frames: u32,
+ _audio_samplerate_hz: u32,
+) -> Result<RtPriorityHandleInternal, AudioThreadPriorityError> {
+ avrt()?
+ .set_mm_thread_characteristics(w!("Audio"))
+ .map(|(mmcss_task_index, task_handle)| {
+ info!("task {mmcss_task_index} bumped to real time priority.");
+ RtPriorityHandleInternal::new(mmcss_task_index, task_handle)
+ })
+ .map_err(|win32_error| {
+ AudioThreadPriorityError::new(&format!(
+ "Unable to bump the thread priority ({win32_error})"
+ ))
+ })
+}
+
pub fn demote_current_thread_from_real_time_internal(
rt_priority_handle: RtPriorityHandleInternal,
) -> Result<(), AudioThreadPriorityError> {
- let rv = unsafe { AvRevertMmThreadCharacteristics(rt_priority_handle.task_handle) };
- if rv == FALSE {
- return Err(AudioThreadPriorityError::new(&format!(
- "Unable to restore the thread priority ({:?})",
- unsafe { GetLastError() }
- )));
+ let RtPriorityHandleInternal {
+ mmcss_task_index,
+ task_handle,
+ } = rt_priority_handle;
+ avrt()?
+ .revert_mm_thread_characteristics(task_handle)
+ .map(|_| {
+ info!("task {mmcss_task_index} priority restored.");
+ })
+ .map_err(|win32_error| {
+ AudioThreadPriorityError::new(&format!(
+ "Unable to restore the thread priority for task {mmcss_task_index} ({win32_error})"
+ ))
+ })
+}
+
+mod avrt_lib {
+ use super::win32_utils::{win32_error_if, OwnedLibrary};
+ use std::sync::Once;
+ use windows_sys::{
+ core::PCWSTR,
+ s, w,
+ Win32::Foundation::{BOOL, FALSE, HANDLE, WIN32_ERROR},
+ };
+
+ type AvSetMmThreadCharacteristicsWFn = unsafe extern "system" fn(PCWSTR, *mut u32) -> HANDLE;
+ type AvRevertMmThreadCharacteristicsFn = unsafe extern "system" fn(HANDLE) -> BOOL;
+
+ #[derive(Debug)]
+ pub(super) struct AvRtLibrary {
+ // This field is never read because only used for its Drop behavior
+ #[allow(dead_code)]
+ module: OwnedLibrary,
+
+ av_set_mm_thread_characteristics_w: AvSetMmThreadCharacteristicsWFn,
+ av_revert_mm_thread_characteristics: AvRevertMmThreadCharacteristicsFn,
}
- info!(
- "task {} priority restored.",
- rt_priority_handle.mmcss_task_index
- );
+ impl AvRtLibrary {
+ pub(super) fn try_new() -> Result<Self, WIN32_ERROR> {
+ let module = OwnedLibrary::try_new(w!("avrt.dll"))?;
+ let av_set_mm_thread_characteristics_w = unsafe {
+ std::mem::transmute::<_, AvSetMmThreadCharacteristicsWFn>(
+ module.get_proc(s!("AvSetMmThreadCharacteristicsW"))?,
+ )
+ };
+ let av_revert_mm_thread_characteristics = unsafe {
+ std::mem::transmute::<_, AvRevertMmThreadCharacteristicsFn>(
+ module.get_proc(s!("AvRevertMmThreadCharacteristics"))?,
+ )
+ };
+ Ok(Self {
+ module,
+ av_set_mm_thread_characteristics_w,
+ av_revert_mm_thread_characteristics,
+ })
+ }
+
+ pub(super) fn set_mm_thread_characteristics(
+ &self,
+ task_name: PCWSTR,
+ ) -> Result<(u32, HANDLE), WIN32_ERROR> {
+ // Ensure that the first call never runs in parallel with other calls. This
+ // seems necessary to guarantee the success of these other calls. We saw them
+ // fail with an error code of ERROR_PATH_NOT_FOUND in tests, presumably on a
+ // machine where the MMCSS service was initially inactive.
+ static FIRST_CALL: Once = Once::new();
+ let mut first_call_result = None;
+ FIRST_CALL.call_once(|| {
+ first_call_result = Some(self.set_mm_thread_characteristics_internal(task_name))
+ });
+ first_call_result
+ .unwrap_or_else(|| self.set_mm_thread_characteristics_internal(task_name))
+ }
+
+ fn set_mm_thread_characteristics_internal(
+ &self,
+ task_name: PCWSTR,
+ ) -> Result<(u32, HANDLE), WIN32_ERROR> {
+ let mut mmcss_task_index = 0u32;
+ let task_handle = unsafe {
+ (self.av_set_mm_thread_characteristics_w)(task_name, &mut mmcss_task_index)
+ };
+ win32_error_if(task_handle == 0)?;
+ Ok((mmcss_task_index, task_handle))
+ }
- Ok(())
+ pub(super) fn revert_mm_thread_characteristics(
+ &self,
+ handle: HANDLE,
+ ) -> Result<(), WIN32_ERROR> {
+ let rv = unsafe { (self.av_revert_mm_thread_characteristics)(handle) };
+ win32_error_if(rv == FALSE)
+ }
+ }
}
-pub fn promote_current_thread_to_real_time_internal(
- _audio_buffer_frames: u32,
- _audio_samplerate_hz: u32,
-) -> Result<RtPriorityHandleInternal, AudioThreadPriorityError> {
- let mut task_index = 0u32;
+mod win32_utils {
+ use windows_sys::{
+ core::{PCSTR, PCWSTR},
+ Win32::{
+ Foundation::{FreeLibrary, GetLastError, HMODULE, WIN32_ERROR},
+ System::LibraryLoader::{GetProcAddress, LoadLibraryW},
+ },
+ };
+
+ pub(super) fn win32_error_if(condition: bool) -> Result<(), WIN32_ERROR> {
+ if condition {
+ Err(unsafe { GetLastError() })
+ } else {
+ Ok(())
+ }
+ }
+
+ #[derive(Debug)]
+ pub(super) struct OwnedLibrary(HMODULE);
+
+ impl OwnedLibrary {
+ pub(super) fn try_new(lib_file_name: PCWSTR) -> Result<Self, WIN32_ERROR> {
+ let module = unsafe { LoadLibraryW(lib_file_name) };
+ win32_error_if(module == 0)?;
+ Ok(Self(module))
+ }
- let handle = unsafe { AvSetMmThreadCharacteristicsA(s!("Audio"), &mut task_index) };
- let handle = RtPriorityHandleInternal::new(task_index, handle);
+ fn raw(&self) -> HMODULE {
+ self.0
+ }
- if handle.task_handle == 0 {
- return Err(AudioThreadPriorityError::new(&format!(
- "Unable to restore the thread priority ({:?})",
- unsafe { GetLastError() }
- )));
+ /// SAFETY: The caller must transmute the value wrapped in a Ok(_) to the correct
+ /// function type, with the correct extern specifier.
+ pub(super) unsafe fn get_proc(
+ &self,
+ proc_name: PCSTR,
+ ) -> Result<unsafe extern "system" fn() -> isize, WIN32_ERROR> {
+ let proc = unsafe { GetProcAddress(self.raw(), proc_name) };
+ win32_error_if(proc.is_none())?;
+ Ok(proc.unwrap())
+ }
}
- info!(
- "task {} bumped to real time priority.",
- handle.mmcss_task_index
- );
+ impl Drop for OwnedLibrary {
+ fn drop(&mut self) {
+ unsafe {
+ FreeLibrary(self.raw());
+ }
+ }
+ }
+}
- Ok(handle)
+#[cfg(test)]
+mod tests {
+ use super::{
+ avrt, demote_current_thread_from_real_time_internal,
+ promote_current_thread_to_real_time_internal,
+ };
+
+ #[test]
+ fn test_successful_avrt_library_load() {
+ assert!(avrt().is_ok())
+ }
+
+ #[test]
+ fn test_successful_api_use() {
+ let handle = promote_current_thread_to_real_time_internal(512, 44100);
+ println!("handle: {handle:?}");
+ assert!(handle.is_ok());
+
+ let result = demote_current_thread_from_real_time_internal(handle.unwrap());
+ println!("result: {result:?}");
+ assert!(result.is_ok());
+ }
}