summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-crypto/src/aead_fuzzing.rs
blob: 4e5a6de07f238032955bc0ef287518438ec98d4d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::fmt;

use crate::{
    constants::{Cipher, Version},
    err::{sec::SEC_ERROR_BAD_DATA, Error, Res},
    p11::SymKey,
    RealAead,
};

/// Fixed 16-byte tag (every byte is `0x0a`) used in place of a real AEAD
/// authentication tag when operating in fuzzing mode.
pub const FIXED_TAG_FUZZING: &[u8] = &[0x0a_u8; 16];

/// AEAD wrapper that can run in a fuzzing mode.
///
/// With a real AEAD present every operation is delegated to it. Without one,
/// data is copied through unmodified and [`FIXED_TAG_FUZZING`] is used as the
/// tag in place of a computed one.
pub struct FuzzingAead {
    /// `Some` holds the real AEAD that calls are forwarded to; `None` selects
    /// the fuzzing behavior.
    real: Option<RealAead>,
}

impl FuzzingAead {
    pub fn new(
        fuzzing: bool,
        version: Version,
        cipher: Cipher,
        secret: &SymKey,
        prefix: &str,
    ) -> Res<Self> {
        let real = if fuzzing {
            None
        } else {
            Some(RealAead::new(false, version, cipher, secret, prefix)?)
        };
        Ok(Self { real })
    }

    #[must_use]
    pub fn expansion(&self) -> usize {
        if let Some(aead) = &self.real {
            aead.expansion()
        } else {
            FIXED_TAG_FUZZING.len()
        }
    }

    pub fn encrypt<'a>(
        &self,
        count: u64,
        aad: &[u8],
        input: &[u8],
        output: &'a mut [u8],
    ) -> Res<&'a [u8]> {
        if let Some(aead) = &self.real {
            return aead.encrypt(count, aad, input, output);
        }

        let l = input.len();
        output[..l].copy_from_slice(input);
        output[l..l + 16].copy_from_slice(FIXED_TAG_FUZZING);
        Ok(&output[..l + 16])
    }

    pub fn decrypt<'a>(
        &self,
        count: u64,
        aad: &[u8],
        input: &[u8],
        output: &'a mut [u8],
    ) -> Res<&'a [u8]> {
        if let Some(aead) = &self.real {
            return aead.decrypt(count, aad, input, output);
        }

        if input.len() < FIXED_TAG_FUZZING.len() {
            return Err(Error::from(SEC_ERROR_BAD_DATA));
        }

        let len_encrypted = input.len() - FIXED_TAG_FUZZING.len();
        // Check that:
        // 1) expansion is all zeros and
        // 2) if the encrypted data is also supplied that at least some values are no zero
        //    (otherwise padding will be interpreted as a valid packet)
        if &input[len_encrypted..] == FIXED_TAG_FUZZING
            && (len_encrypted == 0 || input[..len_encrypted].iter().any(|x| *x != 0x0))
        {
            output[..len_encrypted].copy_from_slice(&input[..len_encrypted]);
            Ok(&output[..len_encrypted])
        } else {
            Err(Error::from(SEC_ERROR_BAD_DATA))
        }
    }
}

impl fmt::Debug for FuzzingAead {
    /// Formats via the wrapped real AEAD when there is one; otherwise emits a
    /// fixed marker identifying the fuzzing stub.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match &self.real {
            Some(inner) => inner.fmt(f),
            None => f.write_str("[FUZZING AEAD]"),
        }
    }
}