summaryrefslogtreecommitdiffstats
path: root/vendor/curl/src/easy/windows.rs
blob: cbf2817cf615dd51cfd34eabef5bafacc5dd193c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#![allow(non_camel_case_types, non_snake_case)]

use libc::c_void;

#[cfg(target_env = "msvc")]
mod win {
    use schannel::cert_context::ValidUses;
    use schannel::cert_store::CertStore;
    use std::convert::TryFrom;
    use std::ffi::CString;
    use std::mem;
    use std::ptr;
    use winapi::ctypes::*;
    use winapi::um::libloaderapi::*;
    use winapi::um::wincrypt::*;

    /// Resolves the address of `symbol` exported by the DLL `module`
    /// (e.g. `"libcrypto"`).
    ///
    /// The module is located with `GetModuleHandleW`, so only modules that
    /// are already loaded into the process are considered. Returns `None`
    /// when the module is not loaded, `symbol` contains an interior NUL, or
    /// the module does not export `symbol`.
    fn lookup(module: &str, symbol: &str) -> Option<*const c_void> {
        let symbol = CString::new(symbol).ok()?;
        // GetModuleHandleW expects a NUL-terminated UTF-16 string.
        let module: Vec<u16> = module.encode_utf16().chain(Some(0)).collect();

        // SAFETY: `module` is NUL-terminated and outlives the call.
        let handle = unsafe { GetModuleHandleW(module.as_ptr()) };
        if handle.is_null() {
            // Module not loaded; never hand a NULL module handle to GetProcAddress.
            return None;
        }

        // SAFETY: `handle` is a non-null module handle just returned by
        // GetModuleHandleW, and `symbol` is a NUL-terminated C string that
        // outlives the call.
        let addr = unsafe { GetProcAddress(handle, symbol.as_ptr()) };
        if addr.is_null() {
            None
        } else {
            Some(addr as *const c_void)
        }
    }

    // Opaque OpenSSL types; they are only ever handled through raw pointers.
    pub enum X509_STORE {}
    pub enum X509 {}
    pub enum SSL_CTX {}

    // Prototypes of the OpenSSL functions resolved at runtime. The raw
    // addresses are transmuted to these types, so they must match the C
    // declarations exactly.
    type d2i_X509_fn = unsafe extern "C" fn(
        a: *mut *mut X509,
        pp: *mut *const c_uchar,
        length: c_long,
    ) -> *mut X509;
    type X509_free_fn = unsafe extern "C" fn(x: *mut X509);
    type X509_STORE_add_cert_fn =
        unsafe extern "C" fn(store: *mut X509_STORE, x: *mut X509) -> c_int;
    type SSL_CTX_get_cert_store_fn = unsafe extern "C" fn(ctx: *const SSL_CTX) -> *mut X509_STORE;

    /// Function pointers resolved from the OpenSSL DLLs loaded in the process.
    struct OpenSSL {
        d2i_X509: d2i_X509_fn,
        X509_free: X509_free_fn,
        X509_STORE_add_cert: X509_STORE_add_cert_fn,
        SSL_CTX_get_cert_store: SSL_CTX_get_cert_store_fn,
    }

    /// Resolves every function in [`OpenSSL`] from the given, already loaded
    /// crypto and ssl DLLs. Returns `None` if any module or symbol is missing.
    ///
    /// # Safety
    ///
    /// If the named modules are loaded, their exports must match the `*_fn`
    /// prototypes above. The returned pointers are valid only while those
    /// modules remain loaded.
    unsafe fn lookup_functions(crypto_module: &str, ssl_module: &str) -> Option<OpenSSL> {
        // Binds a local named after each symbol, returning `None` from the
        // enclosing function if the lookup fails.
        macro_rules! get {
            ($(let $sym:ident in $module:expr;)*) => ($(
                let $sym = lookup($module, stringify!($sym))?;
            )*)
        }
        get! {
            let d2i_X509 in crypto_module;
            let X509_free in crypto_module;
            let X509_STORE_add_cert in crypto_module;
            let SSL_CTX_get_cert_store in ssl_module;
        }
        Some(OpenSSL {
            d2i_X509: mem::transmute::<*const c_void, d2i_X509_fn>(d2i_X509),
            X509_free: mem::transmute::<*const c_void, X509_free_fn>(X509_free),
            X509_STORE_add_cert: mem::transmute::<*const c_void, X509_STORE_add_cert_fn>(
                X509_STORE_add_cert,
            ),
            SSL_CTX_get_cert_store: mem::transmute::<*const c_void, SSL_CTX_get_cert_store_fn>(
                SSL_CTX_get_cert_store,
            ),
        })
    }

    /// Resolves the OpenSSL functions from the DLLs that match the SSL backend
    /// libcurl reports at runtime.
    ///
    /// Recognised backends:
    /// - OpenSSL 1.1.x, looked up in `libcrypto`/`libssl`.
    /// - OpenSSL 1.0.2, looked up in `libeay32`/`ssleay32`.
    ///
    /// Anything else yields `None`.
    ///
    /// # Safety
    ///
    /// Same contract as [`lookup_functions`].
    unsafe fn openssl_for_runtime() -> Option<OpenSSL> {
        let version = crate::version::Version::get();
        let ssl_ver = version.ssl_version()?;
        // OpenSSL 1.1.1 is ABI-compatible with 1.1.0 for the four functions
        // used here, so accept the whole 1.1.x series. If the DLLs are named
        // differently, `lookup` fails and nothing is added.
        if ssl_ver.starts_with("OpenSSL/1.1.") {
            lookup_functions("libcrypto", "libssl")
        } else if ssl_ver.starts_with("OpenSSL/1.0.2") {
            lookup_functions("libeay32", "ssleay32")
        } else {
            None
        }
    }

    /// Whether a certificate with these extended key usages may authenticate
    /// a TLS server: either all usages are allowed, or the "Server
    /// Authentication" OID is listed.
    fn allows_server_auth(uses: &ValidUses) -> bool {
        match uses {
            ValidUses::All => true,
            ValidUses::Oids(oids) => oids
                .iter()
                .any(|oid| oid.as_str() == szOID_PKIX_KP_SERVER_AUTH),
        }
    }

    /// Adds the current user's Windows `ROOT` certificates that are valid for
    /// TLS server authentication to the certificate store of an OpenSSL
    /// `SSL_CTX`.
    ///
    /// This is best-effort. It returns without doing anything if:
    /// - libcurl's SSL backend is not a recognised OpenSSL version;
    /// - the OpenSSL DLLs or functions cannot be found;
    /// - the context has no certificate store;
    /// - the Windows store cannot be opened.
    ///
    /// Individual certificates that fail to parse are skipped.
    ///
    /// # Safety
    ///
    /// `ssl_ctx` must point to a live `SSL_CTX` created by the OpenSSL library
    /// that libcurl is using at runtime.
    pub unsafe fn add_certs_to_context(ssl_ctx: *mut c_void) {
        let openssl = match openssl_for_runtime() {
            Some(openssl) => openssl,
            None => return,
        };

        let openssl_store = (openssl.SSL_CTX_get_cert_store)(ssl_ctx as *const SSL_CTX);
        if openssl_store.is_null() {
            return;
        }

        let store = match CertStore::open_current_user("ROOT") {
            Ok(store) => store,
            Err(_) => return,
        };

        for cert in store.certs() {
            // Skip certificates whose usages can't be read or that are not
            // meant for server authentication.
            match cert.valid_uses() {
                Ok(ref uses) if allows_server_auth(uses) => {}
                _ => continue,
            }

            let der = cert.to_der();
            // d2i_X509 takes a C `long` length; skip anything that doesn't fit
            // instead of silently truncating.
            let len = match c_long::try_from(der.len()) {
                Ok(len) => len,
                Err(_) => continue,
            };
            // d2i_X509 advances this pointer past the bytes it parsed.
            let mut der_ptr = der.as_ptr();
            let x509 = (openssl.d2i_X509)(ptr::null_mut(), &mut der_ptr, len);
            if x509.is_null() {
                continue;
            }
            // The return value is ignored: a failure (e.g. a duplicate
            // certificate) only means this one is not added. The store takes
            // its own reference on success, so ours is released either way.
            (openssl.X509_STORE_add_cert)(openssl_store, x509);
            (openssl.X509_free)(x509);
        }
    }
}

/// Loads the Windows root certificates into the OpenSSL `SSL_CTX` at `ssl_ctx`.
///
/// Only compiled for MSVC targets; it forwards to the Windows-specific
/// implementation.
///
/// NOTE(review): this is a safe `fn` that passes an unchecked raw pointer into
/// `unsafe` code. Callers must guarantee `ssl_ctx` is a live `SSL_CTX`
/// belonging to the OpenSSL library libcurl uses. Consider making this an
/// `unsafe fn`.
#[cfg(target_env = "msvc")]
pub fn add_certs_to_context(ssl_ctx: *mut c_void) {
    // SAFETY: relies on the caller contract described in the doc comment above.
    unsafe { win::add_certs_to_context(ssl_ctx as *mut _) }
}

/// No-op on non-MSVC targets: the Windows certificate store is only bridged
/// into OpenSSL on `target_env = "msvc"` builds.
#[cfg(not(target_env = "msvc"))]
pub fn add_certs_to_context(_ssl_ctx: *mut c_void) {
    // Intentionally empty; the pointer is never dereferenced.
}