// Take a look at the license at the top of the repository in the LICENSE file. use std::{fmt::Display, str::FromStr}; use winapi::{ shared::{ sddl::{ConvertSidToStringSidW, ConvertStringSidToSidW}, winerror::ERROR_INSUFFICIENT_BUFFER, }, um::{ errhandlingapi::GetLastError, securitybaseapi::{CopySid, GetLengthSid, IsValidSid}, winbase::{LocalFree, LookupAccountSidW}, winnt::{SidTypeUnknown, LPWSTR, PSID}, }, }; use crate::sys::utils::to_str; #[doc = include_str!("../../md_doc/sid.md")] #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Sid { sid: Vec, } impl Sid { /// Creates an `Sid` by making a copy of the given raw SID. pub(crate) unsafe fn from_psid(psid: PSID) -> Option { if psid.is_null() { return None; } if IsValidSid(psid) == 0 { return None; } let length = GetLengthSid(psid); let mut sid = vec![0; length as usize]; if CopySid(length, sid.as_mut_ptr() as *mut _, psid) == 0 { sysinfo_debug!("CopySid failed: {:?}", GetLastError()); return None; } // We are making assumptions about the SID internal structure, // and these only hold if the revision is 1 // https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-sid // Namely: // 1. SIDs can be compared directly (memcmp). // 2. Following from this, to hash a SID we can just hash its bytes. // These are the basis for deriving PartialEq, Eq, and Hash. // And since we also need PartialOrd and Ord, we might as well derive them // too. The default implementation will be consistent with Eq, // and we don't care about the actual order, just that there is one. // So it should all work out. // Why bother with this? Because it makes the implementation that // much simpler :) assert_eq!(sid[0], 1, "Expected SID revision to be 1"); Some(Self { sid }) } /// Retrieves the account name of this SID. pub(crate) fn account_name(&self) -> Option { unsafe { let mut name_len = 0; let mut domain_len = 0; let mut name_use = SidTypeUnknown; if LookupAccountSidW( std::ptr::null_mut(), self.sid.as_ptr() as *mut _, std::ptr::null_mut(), &mut name_len, std::ptr::null_mut(), &mut domain_len, &mut name_use, ) == 0 { let error = GetLastError(); if error != ERROR_INSUFFICIENT_BUFFER { sysinfo_debug!("LookupAccountSidW failed: {:?}", error); return None; } } let mut name = vec![0; name_len as usize]; // Reset length to 0 since we're still passing a NULL pointer // for the domain. domain_len = 0; if LookupAccountSidW( std::ptr::null_mut(), self.sid.as_ptr() as *mut _, name.as_mut_ptr(), &mut name_len, std::ptr::null_mut(), &mut domain_len, &mut name_use, ) == 0 { sysinfo_debug!("LookupAccountSidW failed: {:?}", GetLastError()); return None; } Some(to_str(name.as_mut_ptr())) } } } impl Display for Sid { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { unsafe fn convert_sid_to_string_sid(sid: PSID) -> Option { let mut string_sid: LPWSTR = std::ptr::null_mut(); if ConvertSidToStringSidW(sid, &mut string_sid) == 0 { sysinfo_debug!("ConvertSidToStringSidW failed: {:?}", GetLastError()); return None; } let result = to_str(string_sid); LocalFree(string_sid as *mut _); Some(result) } let string_sid = unsafe { convert_sid_to_string_sid(self.sid.as_ptr() as *mut _) }; let string_sid = string_sid.ok_or(std::fmt::Error)?; write!(f, "{string_sid}") } } impl FromStr for Sid { type Err = String; fn from_str(s: &str) -> Result { unsafe { let mut string_sid: Vec = s.encode_utf16().collect(); string_sid.push(0); let mut psid: PSID = std::ptr::null_mut(); if ConvertStringSidToSidW(string_sid.as_ptr(), &mut psid) == 0 { return Err(format!( "ConvertStringSidToSidW failed: {:?}", GetLastError() )); } let sid = Self::from_psid(psid); LocalFree(psid as *mut _); // Unwrapping because ConvertStringSidToSidW should've performed // all the necessary validations. If it returned an invalid SID, // we better fail fast. Ok(sid.unwrap()) } } }