diff options
Diffstat (limited to 'vendor/sysinfo/src/windows/sid.rs')
-rw-r--r-- | vendor/sysinfo/src/windows/sid.rs | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/vendor/sysinfo/src/windows/sid.rs b/vendor/sysinfo/src/windows/sid.rs new file mode 100644 index 000000000..44d94f2f5 --- /dev/null +++ b/vendor/sysinfo/src/windows/sid.rs @@ -0,0 +1,157 @@ +// Take a look at the license at the top of the repository in the LICENSE file.
+
+use std::{fmt::Display, str::FromStr};
+
+use winapi::{
+ shared::{
+ sddl::{ConvertSidToStringSidW, ConvertStringSidToSidW},
+ winerror::ERROR_INSUFFICIENT_BUFFER,
+ },
+ um::{
+ errhandlingapi::GetLastError,
+ securitybaseapi::{CopySid, GetLengthSid, IsValidSid},
+ winbase::{LocalFree, LookupAccountSidW},
+ winnt::{SidTypeUnknown, LPWSTR, PSID},
+ },
+};
+
+use crate::sys::utils::to_str;
+
+#[doc = include_str!("../../md_doc/sid.md")]
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct Sid {
+ sid: Vec<u8>,
+}
+
+impl Sid {
+ /// Creates an `Sid` by making a copy of the given raw SID.
+ pub(crate) unsafe fn from_psid(psid: PSID) -> Option<Self> {
+ if psid.is_null() {
+ return None;
+ }
+
+ if IsValidSid(psid) == 0 {
+ return None;
+ }
+
+ let length = GetLengthSid(psid);
+
+ let mut sid = vec![0; length as usize];
+
+ if CopySid(length, sid.as_mut_ptr() as *mut _, psid) == 0 {
+ sysinfo_debug!("CopySid failed: {:?}", GetLastError());
+ return None;
+ }
+
+ // We are making assumptions about the SID internal structure,
+ // and these only hold if the revision is 1
+ // https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-sid
+ // Namely:
+ // 1. SIDs can be compared directly (memcmp).
+ // 2. Following from this, to hash a SID we can just hash its bytes.
+ // These are the basis for deriving PartialEq, Eq, and Hash.
+ // And since we also need PartialOrd and Ord, we might as well derive them
+ // too. The default implementation will be consistent with Eq,
+ // and we don't care about the actual order, just that there is one.
+ // So it should all work out.
+ // Why bother with this? Because it makes the implementation that
+ // much simpler :)
+ assert_eq!(sid[0], 1, "Expected SID revision to be 1");
+
+ Some(Self { sid })
+ }
+
+ /// Retrieves the account name of this SID.
+ pub(crate) fn account_name(&self) -> Option<String> {
+ unsafe {
+ let mut name_len = 0;
+ let mut domain_len = 0;
+ let mut name_use = SidTypeUnknown;
+
+ if LookupAccountSidW(
+ std::ptr::null_mut(),
+ self.sid.as_ptr() as *mut _,
+ std::ptr::null_mut(),
+ &mut name_len,
+ std::ptr::null_mut(),
+ &mut domain_len,
+ &mut name_use,
+ ) == 0
+ {
+ let error = GetLastError();
+ if error != ERROR_INSUFFICIENT_BUFFER {
+ sysinfo_debug!("LookupAccountSidW failed: {:?}", error);
+ return None;
+ }
+ }
+
+ let mut name = vec![0; name_len as usize];
+
+ // Reset length to 0 since we're still passing a NULL pointer
+ // for the domain.
+ domain_len = 0;
+
+ if LookupAccountSidW(
+ std::ptr::null_mut(),
+ self.sid.as_ptr() as *mut _,
+ name.as_mut_ptr(),
+ &mut name_len,
+ std::ptr::null_mut(),
+ &mut domain_len,
+ &mut name_use,
+ ) == 0
+ {
+ sysinfo_debug!("LookupAccountSidW failed: {:?}", GetLastError());
+ return None;
+ }
+
+ Some(to_str(name.as_mut_ptr()))
+ }
+ }
+}
+
+impl Display for Sid {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ unsafe fn convert_sid_to_string_sid(sid: PSID) -> Option<String> {
+ let mut string_sid: LPWSTR = std::ptr::null_mut();
+ if ConvertSidToStringSidW(sid, &mut string_sid) == 0 {
+ sysinfo_debug!("ConvertSidToStringSidW failed: {:?}", GetLastError());
+ return None;
+ }
+ let result = to_str(string_sid);
+ LocalFree(string_sid as *mut _);
+ Some(result)
+ }
+
+ let string_sid = unsafe { convert_sid_to_string_sid(self.sid.as_ptr() as *mut _) };
+ let string_sid = string_sid.ok_or(std::fmt::Error)?;
+
+ write!(f, "{string_sid}")
+ }
+}
+
+impl FromStr for Sid {
+ type Err = String;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ unsafe {
+ let mut string_sid: Vec<u16> = s.encode_utf16().collect();
+ string_sid.push(0);
+
+ let mut psid: PSID = std::ptr::null_mut();
+ if ConvertStringSidToSidW(string_sid.as_ptr(), &mut psid) == 0 {
+ return Err(format!(
+ "ConvertStringSidToSidW failed: {:?}",
+ GetLastError()
+ ));
+ }
+ let sid = Self::from_psid(psid);
+ LocalFree(psid as *mut _);
+
+ // Unwrapping because ConvertStringSidToSidW should've performed
+ // all the necessary validations. If it returned an invalid SID,
+ // we better fail fast.
+ Ok(sid.unwrap())
+ }
+ }
+}
|