summaryrefslogtreecommitdiffstats
path: root/vendor/sysinfo/src/windows/sid.rs
blob: 44d94f2f588b246cd8b2bbe13af0699c70e7aa5c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
// Take a look at the license at the top of the repository in the LICENSE file.

use std::{fmt::Display, str::FromStr};

use winapi::{
    shared::{
        sddl::{ConvertSidToStringSidW, ConvertStringSidToSidW},
        winerror::ERROR_INSUFFICIENT_BUFFER,
    },
    um::{
        errhandlingapi::GetLastError,
        securitybaseapi::{CopySid, GetLengthSid, IsValidSid},
        winbase::{LocalFree, LookupAccountSidW},
        winnt::{SidTypeUnknown, LPWSTR, PSID},
    },
};

use crate::sys::utils::to_str;

#[doc = include_str!("../../md_doc/sid.md")]
// The comparison, ordering and hashing traits are derived directly from the
// byte buffer below. `Sid::from_psid` explains (and asserts) the revision-1
// assumption that makes a plain byte-wise comparison acceptable.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Sid {
    // Raw bytes of the SID; `Sid::from_psid` fills this buffer with `CopySid`.
    sid: Vec<u8>,
}

impl Sid {
    /// Builds an owned `Sid` by copying the bytes of the raw SID behind `psid`.
    ///
    /// Returns `None` if `psid` is null, if `IsValidSid` rejects it, or if
    /// `CopySid` fails.
    ///
    /// # Safety
    ///
    /// `psid` is handed as-is to `IsValidSid`, `GetLengthSid` and `CopySid`, so
    /// unless it is null it must be a pointer those functions may safely read.
    ///
    /// # Panics
    ///
    /// Panics if the first byte of the copied SID (its revision) is not 1.
    pub(crate) unsafe fn from_psid(psid: PSID) -> Option<Self> {
        // Short-circuiting keeps the null check ahead of the validity check.
        if psid.is_null() || IsValidSid(psid) == 0 {
            return None;
        }

        let byte_len = GetLengthSid(psid);
        let mut bytes: Vec<u8> = vec![0; byte_len as usize];

        let copied = CopySid(byte_len, bytes.as_mut_ptr() as *mut _, psid);
        if copied == 0 {
            sysinfo_debug!("CopySid failed: {:?}", GetLastError());
            return None;
        }

        // The derived `PartialEq`, `Eq` and `Hash` implementations treat a SID
        // as an opaque byte string: two SIDs are equal when their bytes are
        // equal, and hashing a SID means hashing its bytes. That shortcut
        // relies on the internal layout of the structure, which we only trust
        // for revision 1:
        // https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-sid
        // `PartialOrd` and `Ord` are derived as well because an ordering is
        // needed; which ordering does not matter, only that it exists and is
        // consistent with `Eq`, which the derived one is.
        // Leaning on the derives keeps the implementation simple, so check the
        // revision here instead of hand-writing those traits.
        assert_eq!(bytes[0], 1, "Expected SID revision to be 1");

        Some(Self { sid: bytes })
    }

    /// Looks up the account name associated with this SID through
    /// `LookupAccountSidW`.
    ///
    /// Returns `None` when either lookup call fails (the first call is allowed
    /// to fail with `ERROR_INSUFFICIENT_BUFFER`, since it only asks for the
    /// required buffer length).
    pub(crate) fn account_name(&self) -> Option<String> {
        use std::ptr::null_mut;

        // The same pointer is used for both calls; the buffer it points to is
        // not touched in between.
        let sid_ptr = self.sid.as_ptr() as PSID;

        unsafe {
            let mut name_len = 0;
            let mut domain_len = 0;
            let mut name_use = SidTypeUnknown;

            // First pass: no output buffers, only the length variables.
            let probe_failed = LookupAccountSidW(
                null_mut(),
                sid_ptr,
                null_mut(),
                &mut name_len,
                null_mut(),
                &mut domain_len,
                &mut name_use,
            ) == 0;

            if probe_failed {
                let code = GetLastError();
                if code != ERROR_INSUFFICIENT_BUFFER {
                    sysinfo_debug!("LookupAccountSidW failed: {:?}", code);
                    return None;
                }
            }

            let mut name_buf: Vec<u16> = vec![0; name_len as usize];

            // The domain pointer stays NULL for the second call as well, so
            // its length goes back to 0.
            domain_len = 0;

            // Second pass: this time with a buffer for the account name.
            let fetched = LookupAccountSidW(
                null_mut(),
                sid_ptr,
                name_buf.as_mut_ptr(),
                &mut name_len,
                null_mut(),
                &mut domain_len,
                &mut name_use,
            );

            if fetched == 0 {
                sysinfo_debug!("LookupAccountSidW failed: {:?}", GetLastError());
                return None;
            }

            Some(to_str(name_buf.as_mut_ptr()))
        }
    }
}

impl Display for Sid {
    /// Writes the string form of this SID as produced by
    /// `ConvertSidToStringSidW`.
    ///
    /// Fails with `std::fmt::Error` when the conversion call reports failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw_sid = self.sid.as_ptr() as PSID;
        let mut wide: LPWSTR = std::ptr::null_mut();

        let rendered = unsafe {
            if ConvertSidToStringSidW(raw_sid, &mut wide) == 0 {
                sysinfo_debug!("ConvertSidToStringSidW failed: {:?}", GetLastError());
                return Err(std::fmt::Error);
            }

            // Turn the wide string into an owned `String` first, then release
            // the buffer the conversion call handed back.
            let text = to_str(wide);
            LocalFree(wide as *mut _);
            text
        };

        f.write_str(&rendered)
    }
}

impl FromStr for Sid {
    type Err = String;

    /// Parses a SID from its string form using `ConvertStringSidToSidW`.
    ///
    /// # Errors
    ///
    /// Returns an error message if `s` contains a NUL character or if
    /// `ConvertStringSidToSidW` rejects the string.
    ///
    /// # Panics
    ///
    /// Panics if the SID returned by `ConvertStringSidToSidW` cannot be turned
    /// into a `Sid` by `Sid::from_psid`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // The string is handed to the API NUL-terminated, so an interior NUL
        // would silently cut the input short and let trailing garbage through
        // (e.g. "S-1-5-18\0junk" would parse as "S-1-5-18"). Reject it instead.
        if s.contains('\0') {
            return Err("SID string must not contain NUL characters".to_string());
        }

        // UTF-16 encoding of `s` followed by the terminating NUL.
        let wide: Vec<u16> = s.encode_utf16().chain(std::iter::once(0)).collect();

        unsafe {
            let mut psid: PSID = std::ptr::null_mut();
            if ConvertStringSidToSidW(wide.as_ptr(), &mut psid) == 0 {
                return Err(format!(
                    "ConvertStringSidToSidW failed: {:?}",
                    GetLastError()
                ));
            }

            // `from_psid` copies the SID, so the buffer handed back by the
            // conversion call is released right after.
            let sid = Self::from_psid(psid);
            LocalFree(psid as *mut _);

            // ConvertStringSidToSidW should have performed all the necessary
            // validation. If it still returned an invalid SID, fail fast.
            Ok(sid.expect("ConvertStringSidToSidW returned an invalid SID"))
        }
    }
}