summaryrefslogtreecommitdiffstats
path: root/crates/credential/cargo-credential-wincred/src/main.rs
blob: 8ae48f348998638957519cea70ad4f1729a8d3c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
//! Cargo registry windows credential process.

use cargo_credential::{Credential, Error};
use std::ffi::OsStr;
use std::os::windows::ffi::OsStrExt;

use windows_sys::core::PWSTR;
use windows_sys::Win32::Foundation::ERROR_NOT_FOUND;
use windows_sys::Win32::Foundation::FILETIME;
use windows_sys::Win32::Foundation::TRUE;
use windows_sys::Win32::Security::Credentials::CredDeleteW;
use windows_sys::Win32::Security::Credentials::CredFree;
use windows_sys::Win32::Security::Credentials::CredReadW;
use windows_sys::Win32::Security::Credentials::CredWriteW;
use windows_sys::Win32::Security::Credentials::CREDENTIALW;
use windows_sys::Win32::Security::Credentials::CRED_PERSIST_LOCAL_MACHINE;
use windows_sys::Win32::Security::Credentials::CRED_TYPE_GENERIC;

/// Credential provider backed by the Windows Credential Manager.
///
/// This is a stateless unit struct. Its `Credential` impl reads, writes, and
/// deletes tokens through the Win32 `CredReadW`, `CredWriteW`, and
/// `CredDeleteW` calls.
struct WindowsCredential;

/// Encodes `s` as a sequence of UTF-16 code units terminated by a single
/// nul, the form the wide-character (`...W`) Win32 functions take.
///
/// # Panics
///
/// Panics with "nul byte in wide string" if `s` already contains a nul,
/// because a nul-terminated consumer would see the string cut short there.
fn wstr(s: &str) -> Vec<u16> {
    let mut units: Vec<u16> = OsStr::new(s).encode_wide().collect();
    assert!(!units.contains(&0), "nul byte in wide string");
    units.push(0);
    units
}

/// Builds the Credential Manager target name for a registry.
///
/// The result is the text `cargo-registry:` followed by `registry_name`,
/// encoded as a nul-terminated wide string via `wstr`. Within this file the
/// callers pass the registry's index URL as `registry_name`.
///
/// # Panics
///
/// Panics (via `wstr`) if `registry_name` contains a nul.
fn target_name(registry_name: &str) -> Vec<u16> {
    let target = ["cargo-registry:", registry_name].concat();
    wstr(&target)
}

impl Credential for WindowsCredential {
    fn name(&self) -> &'static str {
        env!("CARGO_PKG_NAME")
    }

    fn get(&self, index_url: &str) -> Result<String, Error> {
        let target_name = target_name(index_url);
        let p_credential: *mut CREDENTIALW = std::ptr::null_mut() as *mut _;
        unsafe {
            if CredReadW(
                target_name.as_ptr(),
                CRED_TYPE_GENERIC,
                0,
                p_credential as *mut _ as *mut _,
            ) != TRUE
            {
                return Err(
                    format!("failed to fetch token: {}", std::io::Error::last_os_error()).into(),
                );
            }
            let bytes = std::slice::from_raw_parts(
                (*p_credential).CredentialBlob,
                (*p_credential).CredentialBlobSize as usize,
            );
            String::from_utf8(bytes.to_vec()).map_err(|_| "failed to convert token to UTF8".into())
        }
    }

    fn store(&self, index_url: &str, token: &str, name: Option<&str>) -> Result<(), Error> {
        let token = token.as_bytes();
        let target_name = target_name(index_url);
        let comment = match name {
            Some(name) => wstr(&format!("Cargo registry token for {}", name)),
            None => wstr("Cargo registry token"),
        };
        let mut credential = CREDENTIALW {
            Flags: 0,
            Type: CRED_TYPE_GENERIC,
            TargetName: target_name.as_ptr() as PWSTR,
            Comment: comment.as_ptr() as PWSTR,
            LastWritten: FILETIME {
                dwLowDateTime: 0,
                dwHighDateTime: 0,
            },
            CredentialBlobSize: token.len() as u32,
            CredentialBlob: token.as_ptr() as *mut u8,
            Persist: CRED_PERSIST_LOCAL_MACHINE,
            AttributeCount: 0,
            Attributes: std::ptr::null_mut(),
            TargetAlias: std::ptr::null_mut(),
            UserName: std::ptr::null_mut(),
        };
        let result = unsafe { CredWriteW(&mut credential, 0) };
        if result != TRUE {
            let err = std::io::Error::last_os_error();
            return Err(format!("failed to store token: {}", err).into());
        }
        Ok(())
    }

    fn erase(&self, index_url: &str) -> Result<(), Error> {
        let target_name = target_name(index_url);
        let result = unsafe { CredDeleteW(target_name.as_ptr(), CRED_TYPE_GENERIC, 0) };
        if result != TRUE {
            let err = std::io::Error::last_os_error();
            if err.raw_os_error() == Some(ERROR_NOT_FOUND as i32) {
                eprintln!("not currently logged in to `{}`", index_url);
                return Ok(());
            }
            return Err(format!("failed to remove token: {}", err).into());
        }
        Ok(())
    }
}

/// Process entry point: passes the Windows Credential Manager backend to
/// `cargo_credential::main`.
fn main() {
    let backend = WindowsCredential;
    cargo_credential::main(backend);
}