summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-crypto/src/secrets.rs
blob: 7e2739cfa5dadbe8024e880b9074d47578a55722 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::agentio::as_c_void;
use crate::constants::Epoch;
use crate::err::Res;
use crate::p11::{PK11SymKey, PK11_ReferenceSymKey, SymKey};
use crate::ssl::{PRFileDesc, SSLSecretCallback, SSLSecretDirection};

use neqo_common::qdebug;
use std::os::raw::c_void;
use std::pin::Pin;

// Declare NSS's experimental `SSL_SecretCallback` function. `SecretHolder::register`
// calls it (as an unsafe function whose result is a `Res<()>`) to install
// `Secrets::secret_available` as the callback for a file descriptor.
experimental_api!(SSL_SecretCallback(
    fd: *mut PRFileDesc,
    cb: SSLSecretCallback,
    arg: *mut c_void,
));

/// The direction a secret applies to, as reported by NSS through the secret callback.
#[derive(Debug, Clone, Copy)]
pub enum SecretDirection {
    /// A secret for the read direction (NSS `ssl_secret_read`).
    Read,
    /// A secret for the write direction (NSS `ssl_secret_write`).
    Write,
}

impl From<SSLSecretDirection::Type> for SecretDirection {
    /// Maps NSS's raw secret-direction value onto [`SecretDirection`].
    ///
    /// # Panics
    ///
    /// Panics if `dir` is neither `ssl_secret_read` nor `ssl_secret_write`.
    #[must_use]
    fn from(dir: SSLSecretDirection::Type) -> Self {
        if dir == SSLSecretDirection::ssl_secret_read {
            Self::Read
        } else if dir == SSLSecretDirection::ssl_secret_write {
            Self::Write
        } else {
            unreachable!()
        }
    }
}

/// Handshake secrets for a single direction, one slot per epoch.
///
/// Slot `epoch - 1` holds the secret for `epoch`. Only three slots exist because only
/// the secrets for the epochs used during the handshake need to be kept here.
#[allow(clippy::module_name_repetitions)]
#[derive(Default, Debug)]
pub struct DirectionalSecrets {
    secrets: [Option<SymKey>; 3],
}

impl DirectionalSecrets {
    /// Maps `epoch` onto its slot in `secrets` (slot `epoch - 1`).
    ///
    /// # Panics
    ///
    /// Panics if `epoch` is 0 or has no slot (i.e. is larger than the number of slots).
    fn slot(&self, epoch: Epoch) -> usize {
        assert!(epoch > 0, "epoch 0 has no stored secret");
        let i = (epoch - 1) as usize;
        assert!(i < self.secrets.len(), "no secret slot for epoch {:?}", epoch);
        i
    }

    /// Stores `key` as the secret for `epoch`.
    ///
    /// Any secret already held for that epoch is overwritten.
    ///
    /// # Panics
    ///
    /// Panics if `epoch` is 0 or has no slot.
    fn put(&mut self, epoch: Epoch, key: SymKey) {
        let i = self.slot(epoch);
        self.secrets[i] = Some(key);
    }

    /// Removes and returns the secret for `epoch`.
    ///
    /// Returns `None` if no secret has been stored for that epoch, or if it was already taken.
    ///
    /// # Panics
    ///
    /// Panics if `epoch` is 0 or has no slot.
    pub fn take(&mut self, epoch: Epoch) -> Option<SymKey> {
        let i = self.slot(epoch);
        self.secrets[i].take()
    }
}

/// Secrets delivered by NSS, kept separately for each direction.
#[derive(Default, Debug)]
pub struct Secrets {
    /// Secrets reported for the read direction.
    r: DirectionalSecrets,
    /// Secrets reported for the write direction.
    w: DirectionalSecrets,
}

impl Secrets {
    /// The callback NSS invokes when a new secret becomes available.
    ///
    /// It is installed by `SecretHolder::register`.
    ///
    /// # Safety
    ///
    /// - `arg` must be the pointer registered alongside this callback: a non-null pointer to
    ///   a live `Secrets` that nothing else is accessing for the duration of the call.
    /// - `secret` must be a valid `PK11SymKey` pointer.
    ///
    /// # Panics
    ///
    /// Panics if `arg` is null, or for any of the reasons listed on `put_raw`.
    /// NOTE(review): a panic here unwinds out of an `extern "C"` function. Recent Rust
    /// toolchains abort in that case; older ones have undefined behavior.
    unsafe extern "C" fn secret_available(
        _fd: *mut PRFileDesc,
        epoch: u16,
        dir: SSLSecretDirection::Type,
        secret: *mut PK11SymKey,
        arg: *mut c_void,
    ) {
        let secrets = arg
            .cast::<Self>()
            .as_mut()
            .expect("secret callback registered without a Secrets argument");
        secrets.put_raw(epoch, dir, secret);
    }

    /// Takes a reference to the NSS key at `key_ptr` and stores it for `epoch` in direction `dir`.
    ///
    /// # Safety
    ///
    /// `key_ptr` must be a valid `PK11SymKey` pointer; it is handed straight to
    /// `PK11_ReferenceSymKey`.
    ///
    /// # Panics
    ///
    /// Panics in three cases:
    /// - `SymKey::from_ptr` rejects the referenced pointer.
    /// - `dir` is not a known direction.
    /// - `epoch` has no slot.
    unsafe fn put_raw(
        &mut self,
        epoch: Epoch,
        dir: SSLSecretDirection::Type,
        key_ptr: *mut PK11SymKey,
    ) {
        // Add our own reference before wrapping the key in a `SymKey`. Presumably the pointer
        // passed to the callback is only borrowed from NSS — confirm against sslexp.h.
        let key_ptr = PK11_ReferenceSymKey(key_ptr);
        let key = SymKey::from_ptr(key_ptr).expect("NSS shouldn't be passing out NULL secrets");
        self.put(SecretDirection::from(dir), epoch, key);
    }

    /// Stores `key` as the `dir` secret for `epoch`.
    fn put(&mut self, dir: SecretDirection, epoch: Epoch, key: SymKey) {
        // NOTE(review): `key` is logged through its `Debug` impl. Confirm that impl does not
        // print raw key material.
        qdebug!("{:?} secret available for {:?}: {:?}", dir, epoch, key);
        let keys = match dir {
            SecretDirection::Read => &mut self.r,
            SecretDirection::Write => &mut self.w,
        };
        keys.put(epoch, key);
    }
}

/// Owner of the `Secrets` that NSS fills in through the secret callback.
///
/// The secrets live in a pinned box, so the pointer given to NSS by `register` keeps
/// pointing at the same heap allocation even if the `SecretHolder` itself is moved.
#[derive(Debug)]
pub struct SecretHolder {
    secrets: Pin<Box<Secrets>>,
}

impl SecretHolder {
    /// This registers with NSS.  The lifetime of this object needs to match the lifetime
    /// of the connection, or bad things might happen.
    pub fn register(&mut self, fd: *mut PRFileDesc) -> Res<()> {
        let p = as_c_void(&mut self.secrets);
        unsafe { SSL_SecretCallback(fd, Some(Secrets::secret_available), p) }
    }

    pub fn take_read(&mut self, epoch: Epoch) -> Option<SymKey> {
        self.secrets.r.take(epoch)
    }

    pub fn take_write(&mut self, epoch: Epoch) -> Option<SymKey> {
        self.secrets.w.take(epoch)
    }
}

impl Default for SecretHolder {
    /// Creates a holder whose `Secrets` start out empty.
    ///
    /// Secrets only arrive once `register` has installed the NSS callback.
    fn default() -> Self {
        let secrets = Box::pin(Secrets::default());
        Self { secrets }
    }
}