summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-crypto/src/ext.rs
blob: 02ee6340c12d50ca2c414fca7c86f7ae900af5c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::{
    cell::RefCell,
    os::raw::{c_uint, c_void},
    pin::Pin,
    rc::Rc,
};

use crate::{
    agentio::as_c_void,
    constants::{Extension, HandshakeMessage, TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS},
    err::Res,
    null_safe_slice,
    ssl::{
        PRBool, PRFileDesc, SECFailure, SECStatus, SECSuccess, SSLAlertDescription,
        SSLExtensionHandler, SSLExtensionWriter, SSLHandshakeType,
    },
};

// Binding for NSS's experimental `SSL_InstallExtensionHooks`, which registers a
// writer callback and a handler callback (each with its own opaque argument
// pointer) for extension number `extension` on `fd`. Its only caller in this
// file is `ExtensionTracker::new`, which propagates the result with `?`.
experimental_api!(SSL_InstallExtensionHooks(
    fd: *mut PRFileDesc,
    extension: u16,
    writer: SSLExtensionWriter,
    writer_arg: *mut c_void,
    handler: SSLExtensionHandler,
    handler_arg: *mut c_void,
));

/// Outcome of [`ExtensionHandler::write`].
///
/// Derives `Debug` (and the cheap value traits) so callers and tests can
/// inspect and compare results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionWriterResult {
    /// Include the extension. The value is reported to NSS as the length of
    /// the data written into the buffer handed to `write`.
    Write(usize),
    /// Leave the extension out of this handshake message.
    Skip,
}

/// Outcome of [`ExtensionHandler::handle`].
///
/// Derives `Debug` (and the cheap value traits) so callers and tests can
/// inspect and compare results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionHandlerResult {
    /// The received extension is accepted; NSS is told `SECSuccess`.
    Ok,
    /// The received extension is rejected; the alert code is handed to NSS
    /// and the callback reports `SECFailure`.
    Alert(crate::constants::Alert),
}

/// Callbacks that produce and consume the data of a custom TLS extension.
///
/// Both methods have defaults that treat the extension as empty and only
/// permit it in the `ClientHello` and `EncryptedExtensions` messages.
pub trait ExtensionHandler {
    /// Called while handshake message `msg` is being built.
    ///
    /// Write the extension data into `_d` and return
    /// [`ExtensionWriterResult::Write`] with the number of bytes used, or
    /// return [`ExtensionWriterResult::Skip`] to leave the extension out.
    /// The default writes zero bytes for `ClientHello` and
    /// `EncryptedExtensions` and skips every other message.
    fn write(&mut self, msg: HandshakeMessage, _d: &mut [u8]) -> ExtensionWriterResult {
        let permitted = matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS);
        if permitted {
            ExtensionWriterResult::Write(0)
        } else {
            ExtensionWriterResult::Skip
        }
    }

    /// Called when the extension is received in handshake message `msg`,
    /// with `_d` holding the received extension data.
    ///
    /// Return [`ExtensionHandlerResult::Ok`] to accept it, or
    /// [`ExtensionHandlerResult::Alert`] to reject it with the given alert.
    /// The default accepts it in `ClientHello` and `EncryptedExtensions` and
    /// rejects it anywhere else with alert 110 (`unsupported_extension`).
    fn handle(&mut self, msg: HandshakeMessage, _d: &[u8]) -> ExtensionHandlerResult {
        if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) {
            return ExtensionHandlerResult::Alert(110); // unsupported_extension
        }
        ExtensionHandlerResult::Ok
    }
}

/// An extra box around the shared handler. `ExtensionTracker` pins one of
/// these on the heap and gives NSS a raw pointer to it, so the callbacks can
/// reach the handler without touching the `Rc` reference count.
type BoxedExtensionHandler = Box<Rc<RefCell<dyn ExtensionHandler>>>;

/// Owns a custom extension handler that has been registered with NSS.
///
/// NSS keeps a raw pointer into `handler`, so a tracker must outlive the file
/// descriptor it was installed on (see the `# Safety` section of `new`).
pub struct ExtensionTracker {
    /// The extension number passed to `SSL_InstallExtensionHooks`.
    extension: Extension,
    /// Pinned heap allocation whose address is the callback argument given to
    /// NSS; pinning keeps that address stable when the tracker itself moves.
    handler: Pin<Box<BoxedExtensionHandler>>,
}

impl ExtensionTracker {
    /// Recovers the handler from an NSS callback argument and runs `f` on it.
    ///
    /// Technically the `as_mut()` call here is the only unsafe bit, but don't
    /// call this function lightly.
    ///
    /// # Safety
    ///
    /// `arg` must point to a live `BoxedExtensionHandler`, i.e. it must be the
    /// pointer that `new` registered with NSS, and the owning tracker must
    /// still exist.
    ///
    /// # Panics
    ///
    /// Panics if `arg` is null (the `unwrap`) or if the handler's `RefCell` is
    /// already borrowed (the `borrow_mut`). Both callers are `extern "C"`
    /// functions, so such a panic would have to cross the FFI boundary.
    unsafe fn wrap_handler_call<F, T>(arg: *mut c_void, f: F) -> T
    where
        F: FnOnce(&mut dyn ExtensionHandler) -> T,
    {
        let rc = arg.cast::<BoxedExtensionHandler>().as_mut().unwrap();
        f(&mut *rc.borrow_mut())
    }

    /// `SSLExtensionWriter` callback that `new` registers with NSS.
    ///
    /// Exposes the `max_len`-byte buffer at `data` to the handler's `write`.
    /// On `Write(sz)` it stores `sz` in `*len` and returns 1 (true); on `Skip`
    /// it returns 0 (false).
    ///
    /// NOTE(review): `slice::from_raw_parts_mut` requires `data` to be
    /// non-null even when `max_len` is 0; this assumes NSS always supplies a
    /// buffer. Confirm this. `sz` is also not compared against `max_len` here.
    /// The `expect` panics if `sz` does not fit in a `c_uint`.
    // The allow covers the `as` casts below (`max_len as usize` and
    // `message as HandshakeMessage`, the latter possibly narrowing).
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    unsafe extern "C" fn extension_writer(
        _fd: *mut PRFileDesc,
        message: SSLHandshakeType::Type,
        data: *mut u8,
        len: *mut c_uint,
        max_len: c_uint,
        arg: *mut c_void,
    ) -> PRBool {
        let d = std::slice::from_raw_parts_mut(data, max_len as usize);
        Self::wrap_handler_call(arg, |handler| {
            // Cast is safe here because the message type is always part of the enum
            match handler.write(message as HandshakeMessage, d) {
                ExtensionWriterResult::Write(sz) => {
                    *len = c_uint::try_from(sz).expect("integer overflow from extension writer");
                    1
                }
                ExtensionWriterResult::Skip => 0,
            }
        })
    }

    /// `SSLExtensionHandler` callback that `new` registers with NSS.
    ///
    /// Passes the `len` received bytes at `data` to the handler's `handle`,
    /// going through `null_safe_slice` so a null `data` is tolerated. It
    /// returns `SECSuccess` on `Ok`. On `Alert(a)` it stores `a` in `*alert`
    /// and returns `SECFailure`.
    unsafe extern "C" fn extension_handler(
        _fd: *mut PRFileDesc,
        message: SSLHandshakeType::Type,
        data: *const u8,
        len: c_uint,
        alert: *mut SSLAlertDescription,
        arg: *mut c_void,
    ) -> SECStatus {
        let d = null_safe_slice(data, len);
        // The allow covers the `message as HandshakeMessage` cast inside the closure.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        Self::wrap_handler_call(arg, |handler| {
            // Cast is safe here because the message type is always part of the enum
            match handler.handle(message as HandshakeMessage, d) {
                ExtensionHandlerResult::Ok => SECSuccess,
                ExtensionHandlerResult::Alert(a) => {
                    *alert = a;
                    SECFailure
                }
            }
        })
    }

    /// Use the provided handler to manage an extension.  This is quite unsafe.
    ///
    /// Installs `extension_writer` and `extension_handler` on `fd` for
    /// `extension`. Both callbacks receive a pointer to this tracker's pinned
    /// `handler` allocation as their argument.
    ///
    /// # Safety
    ///
    /// The holder of this `ExtensionTracker` needs to ensure that it lives at
    /// least as long as the file descriptor, as NSS provides no way to remove
    /// an extension handler once it is configured.
    ///
    /// # Errors
    ///
    /// If the underlying NSS API fails to register a handler.
    pub unsafe fn new(
        fd: *mut PRFileDesc,
        extension: Extension,
        handler: Rc<RefCell<dyn ExtensionHandler>>,
    ) -> Res<Self> {
        // The ergonomics here aren't great for users of this API, and the
        // implementation is worse. The pinned outer box gives us a stable
        // pointer to the inner box. That pointer is what is passed to NSS.
        //
        // The inner box points to the reference-counted object, and it is what
        // the callbacks get a reference to. That extra wrapper around the Rc
        // avoids any touching of reference counts in callbacks. Touching them
        // there would inevitably lead to leaks, because we don't control how
        // many times a callback is invoked.
        //
        // This way, only this "outer" code deals with the reference count.
        // Moving `tracker` into the returned `Ok` does not move the heap
        // allocation, so the pointer handed to NSS stays valid.
        let mut tracker = Self {
            extension,
            handler: Box::pin(Box::new(handler)),
        };
        SSL_InstallExtensionHooks(
            fd,
            extension,
            Some(Self::extension_writer),
            as_c_void(&mut tracker.handler),
            Some(Self::extension_handler),
            as_c_void(&mut tracker.handler),
        )?;
        Ok(tracker)
    }
}

impl std::fmt::Debug for ExtensionTracker {
    /// Renders as `ExtensionTracker: <extension>`; the handler itself is
    /// deliberately not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let args = format_args!("ExtensionTracker: {:?}", self.extension);
        f.write_fmt(args)
    }
}