summaryrefslogtreecommitdiffstats
path: root/third_party/rust/rust_cascade/src
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/rust_cascade/src')
-rw-r--r--third_party/rust/rust_cascade/src/lib.rs1129
1 files changed, 1129 insertions, 0 deletions
diff --git a/third_party/rust/rust_cascade/src/lib.rs b/third_party/rust/rust_cascade/src/lib.rs
new file mode 100644
index 0000000000..eef8e1f97d
--- /dev/null
+++ b/third_party/rust/rust_cascade/src/lib.rs
@@ -0,0 +1,1129 @@
+//! # rust-cascade
+//!
+//! A library for creating and querying the cascading bloom filters described by
+//! Larisch, Choffnes, Levin, Maggs, Mislove, and Wilson in
+//! "CRLite: A Scalable System for Pushing All TLS Revocations to All Browsers"
+//! <https://www.ieee-security.org/TC/SP2017/papers/567.pdf>
+
+extern crate byteorder;
+extern crate murmurhash3;
+extern crate rand;
+extern crate sha2;
+
+use byteorder::{ByteOrder, LittleEndian, ReadBytesExt};
+use murmurhash3::murmurhash3_x86_32;
+#[cfg(feature = "builder")]
+use rand::rngs::OsRng;
+#[cfg(feature = "builder")]
+use rand::RngCore;
+use sha2::{Digest, Sha256};
+use std::convert::{TryFrom, TryInto};
+use std::fmt;
+use std::io::{ErrorKind, Read};
+use std::mem::size_of;
+
+#[derive(Debug)]
+pub enum CascadeError {
+ LongSalt,
+ TooManyLayers,
+ Collision,
+ UnknownHashFunction,
+ CapacityViolation(&'static str),
+ Parse(&'static str),
+}
+
+impl fmt::Display for CascadeError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match *self {
+ CascadeError::LongSalt => {
+ write!(f, "Cannot serialize a filter with a salt of length >= 256.")
+ }
+ CascadeError::TooManyLayers => {
+ write!(f, "Cannot serialize a filter with >= 255 layers.")
+ }
+ CascadeError::Collision => {
+ write!(f, "Collision between included and excluded sets.")
+ }
+ CascadeError::UnknownHashFunction => {
+ write!(f, "Unknown hash function.")
+ }
+ CascadeError::CapacityViolation(function) => {
+ write!(f, "Unexpected call to {}", function)
+ }
+ CascadeError::Parse(reason) => {
+ write!(f, "Cannot parse cascade: {}", reason)
+ }
+ }
+ }
+}
+
+/// A Bloom filter representing a specific layer in a multi-layer cascading Bloom filter.
+/// The same hash function is used for all layers, so it is not encoded here.
+struct Bloom {
+ /// How many hash functions this filter uses
+ n_hash_funcs: u32,
+ /// The bit length of the filter
+ size: u32,
+ /// The data of the filter
+ data: Vec<u8>,
+}
+
+#[repr(u8)]
+#[derive(Copy, Clone, PartialEq)]
+/// These enumerations need to match the python filter-cascade project:
+/// <https://github.com/mozilla/filter-cascade/blob/v0.3.0/filtercascade/fileformats.py>
+pub enum HashAlgorithm {
+ MurmurHash3 = 1,
+ Sha256l32 = 2, // low 32 bits of sha256
+ Sha256 = 3, // all 256 bits of sha256
+}
+
+impl fmt::Display for HashAlgorithm {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "{}", *self as u8)
+ }
+}
+
+impl TryFrom<u8> for HashAlgorithm {
+ type Error = CascadeError;
+ fn try_from(value: u8) -> Result<HashAlgorithm, CascadeError> {
+ match value {
+ // Naturally, these need to match the enum declaration
+ 1 => Ok(Self::MurmurHash3),
+ 2 => Ok(Self::Sha256l32),
+ 3 => Ok(Self::Sha256),
+ _ => Err(CascadeError::UnknownHashFunction),
+ }
+ }
+}
+
+/// A CascadeIndexGenerator provides one-time access to a table of pseudorandom functions H_ij
+/// in which each function is of the form
+/// H(s: &[u8], r: u32) -> usize
+/// and for which 0 <= H(s,r) < r for all s, r.
+/// The pseudorandom functions share a common key, represented as a octet string, and the table can
+/// be constructed from this key alone. The functions are pseudorandom with respect to s, but not
+/// r. For a uniformly random key/table, fixed r, and arbitrary strings m0 and m1,
+/// H_ij(m0, r) is computationally indistinguishable from H_ij(m1,r)
+/// for all i,j.
+///
+/// A call to next_layer() increments i and resets j.
+/// A call to next_index(s, r) increments j, and outputs some value H_ij(s) with 0 <= H_ij(s) < r.
+
+#[derive(Debug)]
+enum CascadeIndexGenerator {
+ MurmurHash3 {
+ key: Vec<u8>,
+ counter: u32,
+ depth: u8,
+ },
+ Sha256l32 {
+ key: Vec<u8>,
+ counter: u32,
+ depth: u8,
+ },
+ Sha256Ctr {
+ key: Vec<u8>,
+ counter: u32,
+ state: [u8; 32],
+ state_available: u8,
+ },
+}
+
+impl PartialEq for CascadeIndexGenerator {
+ fn eq(&self, other: &Self) -> bool {
+ match (self, other) {
+ (
+ CascadeIndexGenerator::MurmurHash3 { key: ref a, .. },
+ CascadeIndexGenerator::MurmurHash3 { key: ref b, .. },
+ )
+ | (
+ CascadeIndexGenerator::Sha256l32 { key: ref a, .. },
+ CascadeIndexGenerator::Sha256l32 { key: ref b, .. },
+ )
+ | (
+ CascadeIndexGenerator::Sha256Ctr { key: ref a, .. },
+ CascadeIndexGenerator::Sha256Ctr { key: ref b, .. },
+ ) => a == b,
+ _ => false,
+ }
+ }
+}
+
+impl CascadeIndexGenerator {
+ fn new(hash_alg: HashAlgorithm, key: Vec<u8>) -> Self {
+ match hash_alg {
+ HashAlgorithm::MurmurHash3 => Self::MurmurHash3 {
+ key,
+ counter: 0,
+ depth: 1,
+ },
+ HashAlgorithm::Sha256l32 => Self::Sha256l32 {
+ key,
+ counter: 0,
+ depth: 1,
+ },
+ HashAlgorithm::Sha256 => Self::Sha256Ctr {
+ key,
+ counter: 0,
+ state: [0; 32],
+ state_available: 0,
+ },
+ }
+ }
+
+ fn next_layer(&mut self) {
+ match self {
+ Self::MurmurHash3 {
+ ref mut counter,
+ ref mut depth,
+ ..
+ }
+ | Self::Sha256l32 {
+ ref mut counter,
+ ref mut depth,
+ ..
+ } => {
+ *counter = 0;
+ *depth += 1;
+ }
+ Self::Sha256Ctr { .. } => (),
+ }
+ }
+
+ fn next_index(&mut self, salt: &[u8], range: u32) -> usize {
+ let index = match self {
+ Self::MurmurHash3 {
+ key,
+ ref mut counter,
+ depth,
+ } => {
+ let hash_seed = (*counter << 16) + *depth as u32;
+ *counter += 1;
+ murmurhash3_x86_32(key, hash_seed)
+ }
+
+ Self::Sha256l32 {
+ key,
+ ref mut counter,
+ depth,
+ } => {
+ let mut hasher = Sha256::new();
+ hasher.update(salt);
+ hasher.update(counter.to_le_bytes());
+ hasher.update(depth.to_le_bytes());
+ hasher.update(&key);
+ *counter += 1;
+ u32::from_le_bytes(
+ hasher.finalize()[0..4]
+ .try_into()
+ .expect("sha256 should have given enough bytes"),
+ )
+ }
+
+ Self::Sha256Ctr {
+ key,
+ ref mut counter,
+ ref mut state,
+ ref mut state_available,
+ } => {
+ // |bytes_needed| is the minimum number of bytes needed to represent a value in [0, range).
+ let bytes_needed = ((range.next_power_of_two().trailing_zeros() + 7) / 8) as usize;
+ let mut index_arr = [0u8; 4];
+ for byte in index_arr.iter_mut().take(bytes_needed) {
+ if *state_available == 0 {
+ let mut hasher = Sha256::new();
+ hasher.update(counter.to_le_bytes());
+ hasher.update(salt);
+ hasher.update(&key);
+ hasher.finalize_into(state.into());
+ *state_available = state.len() as u8;
+ *counter += 1;
+ }
+ *byte = state[state.len() - *state_available as usize];
+ *state_available -= 1;
+ }
+ LittleEndian::read_u32(&index_arr)
+ }
+ };
+ (index % range) as usize
+ }
+}
+
+impl Bloom {
+ /// `new_crlite_bloom` creates an empty bloom filter for a layer of a cascade with the
+ /// parameters specified in [LCL+17, Section III.C].
+ ///
+ /// # Arguments
+ /// * `include_capacity` - the number of elements that will be encoded at the new layer.
+ /// * `exclude_capacity` - the number of elements in the complement of the encoded set.
+ /// * `top_layer` - whether this is the top layer of the filter.
+ #[cfg(feature = "builder")]
+ pub fn new_crlite_bloom(
+ include_capacity: usize,
+ exclude_capacity: usize,
+ top_layer: bool,
+ ) -> Self {
+ assert!(include_capacity != 0 && exclude_capacity != 0);
+
+ let r = include_capacity as f64;
+ let s = exclude_capacity as f64;
+
+ // The desired false positive rate for the top layer is
+ // p = r/(sqrt(2)*s).
+ // With this setting, the number of false positives (which will need to be
+ // encoded at the second layer) is expected to be a factor of sqrt(2)
+ // smaller than the number of elements encoded at the top layer.
+ //
+ // At layer i > 1 we try to ensure that the number of elements to be
+ // encoded at layer i+1 is half the number of elements encoded at
+ // layer i. So we take p = 1/2.
+ let log2_fp_rate = match top_layer {
+ true => (r / s).log2() - 0.5f64,
+ false => -1f64,
+ };
+
+ // the number of hash functions (k) and the size of the bloom filter (m) are given in
+ // [LCL+17] as k = log2(1/p) and m = r log2(1/p) / ln(2).
+ //
+ // If this formula gives a value of m < 256, we take m=256 instead. This results in very
+ // slightly sub-optimal size, but gives us the added benefit of doing less hashing.
+ let n_hash_funcs = (-log2_fp_rate).round() as u32;
+ let size = match (r * (-log2_fp_rate) / (f64::ln(2f64))).round() as u32 {
+ size if size >= 256 => size,
+ _ => 256,
+ };
+
+ Bloom {
+ n_hash_funcs,
+ size,
+ data: vec![0u8; ((size + 7) / 8) as usize],
+ }
+ }
+
+ /// `read` attempts to decode the Bloom filter represented by the bytes in the given reader.
+ ///
+ /// # Arguments
+ /// * `reader` - The encoded representation of this Bloom filter. May be empty. May include
+ /// additional data describing further Bloom filters.
+ /// The format of an encoded Bloom filter is:
+ /// [1 byte] - the hash algorithm to use in the filter
+ /// [4 little endian bytes] - the length in bits of the filter
+ /// [4 little endian bytes] - the number of hash functions to use in the filter
+ /// [1 byte] - which layer in the cascade this filter is
+ /// [variable length bytes] - the filter itself (must be of minimal length)
+ pub fn read<R: Read>(
+ reader: &mut R,
+ ) -> Result<Option<(Bloom, usize, HashAlgorithm)>, CascadeError> {
+ let hash_algorithm_val = match reader.read_u8() {
+ Ok(val) => val,
+ // If reader is at EOF, there is no bloom filter.
+ Err(e) if e.kind() == ErrorKind::UnexpectedEof => return Ok(None),
+ Err(_) => return Err(CascadeError::Parse("read error")),
+ };
+ let hash_algorithm = HashAlgorithm::try_from(hash_algorithm_val)?;
+
+ let size = reader
+ .read_u32::<byteorder::LittleEndian>()
+ .or(Err(CascadeError::Parse("truncated at layer size")))?;
+ let n_hash_funcs = reader
+ .read_u32::<byteorder::LittleEndian>()
+ .or(Err(CascadeError::Parse("truncated at layer hash count")))?;
+ let layer = reader
+ .read_u8()
+ .or(Err(CascadeError::Parse("truncated at layer number")))?;
+
+ let byte_count = ((size + 7) / 8) as usize;
+ let mut data = vec![0; byte_count];
+ reader
+ .read_exact(&mut data)
+ .or(Err(CascadeError::Parse("truncated at layer data")))?;
+ let bloom = Bloom {
+ n_hash_funcs,
+ size,
+ data,
+ };
+ Ok(Some((bloom, layer as usize, hash_algorithm)))
+ }
+
+ fn has(&self, generator: &mut CascadeIndexGenerator, salt: &[u8]) -> bool {
+ for _ in 0..self.n_hash_funcs {
+ let bit_index = generator.next_index(salt, self.size);
+ assert!(bit_index < self.size as usize);
+ let byte_index = bit_index / 8;
+ let mask = 1 << (bit_index % 8);
+ if self.data[byte_index] & mask == 0 {
+ return false;
+ }
+ }
+ true
+ }
+
+ #[cfg(feature = "builder")]
+ fn insert(&mut self, generator: &mut CascadeIndexGenerator, salt: &[u8]) {
+ for _ in 0..self.n_hash_funcs {
+ let bit_index = generator.next_index(salt, self.size);
+ let byte_index = bit_index / 8;
+ let mask = 1 << (bit_index % 8);
+ self.data[byte_index] |= mask;
+ }
+ }
+
+ pub fn approximate_size_of(&self) -> usize {
+ size_of::<Bloom>() + self.data.len()
+ }
+}
+
+impl fmt::Display for Bloom {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "n_hash_funcs={} size={}", self.n_hash_funcs, self.size)
+ }
+}
+
+/// A multi-layer cascading Bloom filter.
+pub struct Cascade {
+ /// The Bloom filter for this layer in the cascade
+ filters: Vec<Bloom>,
+ /// The salt in use, if any
+ salt: Vec<u8>,
+ /// The hash algorithm / index generating function to use
+ hash_algorithm: HashAlgorithm,
+ /// Whether the logic should be inverted
+ inverted: bool,
+}
+
+impl Cascade {
+ /// from_bytes attempts to decode and return a multi-layer cascading Bloom filter.
+ ///
+ /// # Arguments
+ /// `bytes` - The encoded representation of the Bloom filters in this cascade. Starts with 2
+ /// little endian bytes indicating the version. The current version is 2. The Python
+ /// filter-cascade project defines the formats, see
+ /// <https://github.com/mozilla/filter-cascade/blob/v0.3.0/filtercascade/fileformats.py>
+ ///
+ /// May be of length 0, in which case `None` is returned.
+ pub fn from_bytes(bytes: Vec<u8>) -> Result<Option<Self>, CascadeError> {
+ if bytes.is_empty() {
+ return Ok(None);
+ }
+ let mut reader = bytes.as_slice();
+ let version = reader
+ .read_u16::<byteorder::LittleEndian>()
+ .or(Err(CascadeError::Parse("truncated at version")))?;
+
+ let mut filters = vec![];
+ let mut salt = vec![];
+ let mut top_hash_alg = None;
+ let mut inverted = false;
+
+ if version > 2 {
+ return Err(CascadeError::Parse("unknown version"));
+ }
+
+ if version == 2 {
+ let inverted_val = reader
+ .read_u8()
+ .or(Err(CascadeError::Parse("truncated at inverted")))?;
+ if inverted_val > 1 {
+ return Err(CascadeError::Parse("invalid value for inverted"));
+ }
+ inverted = 0 != inverted_val;
+ let salt_len: usize = reader
+ .read_u8()
+ .or(Err(CascadeError::Parse("truncated at salt length")))?
+ .into();
+ if salt_len >= 256 {
+ return Err(CascadeError::Parse("salt too long"));
+ }
+ if salt_len > 0 {
+ let mut salt_bytes = vec![0; salt_len];
+ reader
+ .read_exact(&mut salt_bytes)
+ .or(Err(CascadeError::Parse("truncated at salt")))?;
+ salt = salt_bytes;
+ }
+ }
+
+ while let Some((filter, layer_number, layer_hash_alg)) = Bloom::read(&mut reader)? {
+ filters.push(filter);
+
+ if layer_number != filters.len() {
+ return Err(CascadeError::Parse("irregular layer numbering"));
+ }
+
+ if *top_hash_alg.get_or_insert(layer_hash_alg) != layer_hash_alg {
+ return Err(CascadeError::Parse("Inconsistent hash algorithms"));
+ }
+ }
+
+ if filters.is_empty() {
+ return Err(CascadeError::Parse("missing filters"));
+ }
+
+ let hash_algorithm = top_hash_alg.ok_or(CascadeError::Parse("missing hash algorithm"))?;
+
+ Ok(Some(Cascade {
+ filters,
+ salt,
+ hash_algorithm,
+ inverted,
+ }))
+ }
+
+ /// to_bytes encodes a cascade in the version 2 format.
+ pub fn to_bytes(&self) -> Result<Vec<u8>, CascadeError> {
+ if self.salt.len() >= 256 {
+ return Err(CascadeError::LongSalt);
+ }
+ if self.filters.len() >= 255 {
+ return Err(CascadeError::TooManyLayers);
+ }
+ let mut out = vec![];
+ let version: u16 = 2;
+ let inverted: u8 = self.inverted.into();
+ let salt_len: u8 = self.salt.len() as u8;
+ let hash_alg: u8 = self.hash_algorithm as u8;
+ out.extend_from_slice(&version.to_le_bytes());
+ out.push(inverted);
+ out.push(salt_len);
+ out.extend_from_slice(&self.salt);
+ for (layer, bloom) in self.filters.iter().enumerate() {
+ out.push(hash_alg);
+ out.extend_from_slice(&bloom.size.to_le_bytes());
+ out.extend_from_slice(&bloom.n_hash_funcs.to_le_bytes());
+ out.push((1 + layer) as u8); // 1-indexed
+ out.extend_from_slice(&bloom.data);
+ }
+ Ok(out)
+ }
+
+ /// has determines if the given sequence of bytes is in the cascade.
+ ///
+ /// # Arguments
+ /// `entry` - The bytes to query
+ pub fn has(&self, entry: Vec<u8>) -> bool {
+ // Query filters 0..self.filters.len() until we get a non-membership result.
+ // If this occurs at an even index filter, the element *is not* included.
+ // ... at an odd-index filter, the element *is* included.
+ let mut generator = CascadeIndexGenerator::new(self.hash_algorithm, entry);
+ let mut rv = false;
+ for filter in &self.filters {
+ if filter.has(&mut generator, &self.salt) {
+ rv = !rv;
+ generator.next_layer();
+ } else {
+ break;
+ }
+ }
+ if self.inverted {
+ rv = !rv;
+ }
+ rv
+ }
+
+ pub fn invert(&mut self) {
+ self.inverted = !self.inverted;
+ }
+
+ /// Determine the approximate amount of memory in bytes used by this
+ /// Cascade. Because this implementation does not integrate with the
+ /// allocator, it can't get an accurate measurement of how much memory it
+ /// uses. However, it can make a reasonable guess, assuming the sizes of
+ /// the bloom filters are large enough to dominate the overall allocated
+ /// size.
+ pub fn approximate_size_of(&self) -> usize {
+ size_of::<Cascade>()
+ + self
+ .filters
+ .iter()
+ .map(|x| x.approximate_size_of())
+ .sum::<usize>()
+ + self.salt.len()
+ }
+}
+
+impl fmt::Display for Cascade {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ writeln!(
+ f,
+ "salt={:?} inverted={} hash_algorithm={}",
+ self.salt, self.inverted, self.hash_algorithm,
+ )?;
+ for filter in &self.filters {
+ writeln!(f, "\t[{}]", filter)?;
+ }
+ Ok(())
+ }
+}
+
+/// A CascadeBuilder creates a Cascade with layers given by `Bloom::new_crlite_bloom`.
+///
+/// A builder is initialized using [`CascadeBuilder::default`] or [`CascadeBuilder::new`]. Prefer `default`. The `new` constructor
+/// allows the user to specify sensitive internal details such as the hash function and the domain
+/// separation parameter.
+///
+/// Both constructors take `include_capacity` and an `exclude_capacity` parameters. The
+/// `include_capacity` is the number of elements that will be encoded in the Cascade. The
+/// `exclude_capacity` is size of the complement of the encoded set.
+///
+/// The encoded set is specified through calls to [`CascadeBuilder::include`]. Its complement is specified through
+/// calls to [`CascadeBuilder::exclude`]. The cascade is built with a call to [`CascadeBuilder::finalize`].
+///
+/// The builder will track of the number of calls to `include` and `exclude`.
+/// The caller is responsible for making *exactly* `include_capacity` calls to `include`
+/// followed by *exactly* `exclude_capacity` calls to `exclude`.
+/// Calling `exclude` before all `include` calls have been made will result in a panic!().
+/// Calling `finalize` before all `exclude` calls have been made will result in a panic!().
+///
+#[cfg(feature = "builder")]
+pub struct CascadeBuilder {
+ filters: Vec<Bloom>,
+ salt: Vec<u8>,
+ hash_algorithm: HashAlgorithm,
+ to_include: Vec<CascadeIndexGenerator>,
+ to_exclude: Vec<CascadeIndexGenerator>,
+ status: BuildStatus,
+}
+
+#[cfg(feature = "builder")]
+impl CascadeBuilder {
+ pub fn default(include_capacity: usize, exclude_capacity: usize) -> Self {
+ let mut salt = vec![0u8; 16];
+ OsRng.fill_bytes(&mut salt);
+ CascadeBuilder::new(
+ HashAlgorithm::Sha256,
+ salt,
+ include_capacity,
+ exclude_capacity,
+ )
+ }
+
+ pub fn new(
+ hash_algorithm: HashAlgorithm,
+ salt: Vec<u8>,
+ include_capacity: usize,
+ exclude_capacity: usize,
+ ) -> Self {
+ CascadeBuilder {
+ filters: vec![Bloom::new_crlite_bloom(
+ include_capacity,
+ exclude_capacity,
+ true,
+ )],
+ salt,
+ to_include: vec![],
+ to_exclude: vec![],
+ hash_algorithm,
+ status: BuildStatus(include_capacity, exclude_capacity),
+ }
+ }
+
+ pub fn include(&mut self, item: Vec<u8>) -> Result<(), CascadeError> {
+ match self.status {
+ BuildStatus(ref mut cap, _) if *cap > 0 => *cap -= 1,
+ _ => return Err(CascadeError::CapacityViolation("include")),
+ }
+ let mut generator = CascadeIndexGenerator::new(self.hash_algorithm, item);
+ self.filters[0].insert(&mut generator, &self.salt);
+ self.to_include.push(generator);
+
+ Ok(())
+ }
+
+ pub fn exclude(&mut self, item: Vec<u8>) -> Result<(), CascadeError> {
+ match self.status {
+ BuildStatus(0, ref mut cap) if *cap > 0 => *cap -= 1,
+ _ => return Err(CascadeError::CapacityViolation("exclude")),
+ }
+ let mut generator = CascadeIndexGenerator::new(self.hash_algorithm, item);
+ if self.filters[0].has(&mut generator, &self.salt) {
+ self.to_exclude.push(generator);
+ }
+ Ok(())
+ }
+
+ /// `exclude_threaded` is like `exclude` but it stores false positives in a caller-owned
+ /// `ExcludeSet`. This allows the caller to exclude items in parallel.
+ pub fn exclude_threaded(&self, exclude_set: &mut ExcludeSet, item: Vec<u8>) {
+ exclude_set.size += 1;
+ let mut generator = CascadeIndexGenerator::new(self.hash_algorithm, item);
+ if self.filters[0].has(&mut generator, &self.salt) {
+ exclude_set.set.push(generator);
+ }
+ }
+
+ /// `collect_exclude_set` merges an `ExcludeSet` into the internal storage of the CascadeBuilder.
+ pub fn collect_exclude_set(
+ &mut self,
+ exclude_set: &mut ExcludeSet,
+ ) -> Result<(), CascadeError> {
+ match self.status {
+ BuildStatus(0, ref mut cap) if *cap >= exclude_set.size => *cap -= exclude_set.size,
+ _ => return Err(CascadeError::CapacityViolation("exclude")),
+ }
+ self.to_exclude.append(&mut exclude_set.set);
+
+ Ok(())
+ }
+
+ fn push_layer(&mut self) -> Result<(), CascadeError> {
+ // At even layers we encode elements of to_include. At odd layers we encode elements of
+ // to_exclude. In both cases, we track false positives by filtering the complement of the
+ // encoded set through the newly produced bloom filter.
+ let at_even_layer = self.filters.len() % 2 == 0;
+ let (to_encode, to_filter) = match at_even_layer {
+ true => (&mut self.to_include, &mut self.to_exclude),
+ false => (&mut self.to_exclude, &mut self.to_include),
+ };
+
+ // split ownership of `salt` away from `to_encode` and `to_filter`
+ // We need an immutable reference to salt during `to_encode.iter_mut()`
+ let mut bloom = Bloom::new_crlite_bloom(to_encode.len(), to_filter.len(), false);
+
+ let salt = self.salt.as_slice();
+
+ to_encode.iter_mut().for_each(|x| {
+ x.next_layer();
+ bloom.insert(x, salt)
+ });
+
+ let mut delta = to_filter.len();
+ to_filter.retain_mut(|x| {
+ x.next_layer();
+ bloom.has(x, salt)
+ });
+ delta -= to_filter.len();
+
+ if delta == 0 {
+ // Check for collisions between the |to_encode| and |to_filter| sets.
+ // The implementation of PartialEq for CascadeIndexGenerator will successfully
+ // identify cases where the user called |include(item)| and |exclude(item)| for the
+ // same item. It will not identify collisions in the underlying hash function.
+ for x in to_encode.iter_mut() {
+ if to_filter.contains(x) {
+ return Err(CascadeError::Collision);
+ }
+ }
+ }
+
+ self.filters.push(bloom);
+ Ok(())
+ }
+
+ pub fn finalize(mut self) -> Result<Box<Cascade>, CascadeError> {
+ match self.status {
+ BuildStatus(0, 0) => (),
+ _ => return Err(CascadeError::CapacityViolation("finalize")),
+ }
+
+ loop {
+ if self.to_exclude.is_empty() {
+ break;
+ }
+ self.push_layer()?;
+
+ if self.to_include.is_empty() {
+ break;
+ }
+ self.push_layer()?;
+ }
+
+ Ok(Box::new(Cascade {
+ filters: self.filters,
+ salt: self.salt,
+ hash_algorithm: self.hash_algorithm,
+ inverted: false,
+ }))
+ }
+}
+
+/// BuildStatus is used to ensure that the `include`, `exclude`, and `finalize` calls to
+/// CascadeBuilder are made in the right order. The (a,b) state indicates that the
+/// CascadeBuilder is waiting for `a` calls to `include` and `b` calls to `exclude`.
+#[cfg(feature = "builder")]
+struct BuildStatus(usize, usize);
+
+/// CascadeBuilder::exclude takes `&mut self` so that it can count exclusions and push items to
+/// self.to_exclude. The bulk of the work it does, however, can be done with an immutable reference
+/// to the top level bloom filter. An `ExcludeSet` is used by `CascadeBuilder::exclude_threaded` to
+/// track the changes to a `CascadeBuilder` that would be made with a call to
+/// `CascadeBuilder::exclude`.
+#[cfg(feature = "builder")]
+#[derive(Default)]
+pub struct ExcludeSet {
+ size: usize,
+ set: Vec<CascadeIndexGenerator>,
+}
+
+#[cfg(test)]
+mod tests {
+ use Bloom;
+ use Cascade;
+ #[cfg(feature = "builder")]
+ use CascadeBuilder;
+ #[cfg(feature = "builder")]
+ use CascadeError;
+ use CascadeIndexGenerator;
+ #[cfg(feature = "builder")]
+ use ExcludeSet;
+ use HashAlgorithm;
+
+ #[test]
+ fn bloom_v1_test_from_bytes() {
+ let src: Vec<u8> = vec![
+ 0x01, 0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x41, 0x00,
+ ];
+ let mut reader = src.as_slice();
+
+ match Bloom::read(&mut reader) {
+ Ok(Some((bloom, 1, HashAlgorithm::MurmurHash3))) => {
+ assert!(bloom.has(
+ &mut CascadeIndexGenerator::new(HashAlgorithm::MurmurHash3, b"this".to_vec()),
+ &vec![]
+ ));
+ assert!(bloom.has(
+ &mut CascadeIndexGenerator::new(HashAlgorithm::MurmurHash3, b"that".to_vec()),
+ &vec![]
+ ));
+ assert!(!bloom.has(
+ &mut CascadeIndexGenerator::new(HashAlgorithm::MurmurHash3, b"other".to_vec()),
+ &vec![]
+ ));
+ }
+ Ok(_) => panic!("Parsing failed"),
+ Err(_) => panic!("Parsing failed"),
+ };
+ assert!(reader.is_empty());
+
+ let short: Vec<u8> = vec![
+ 0x01, 0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x41,
+ ];
+ assert!(Bloom::read(&mut short.as_slice()).is_err());
+
+ let empty: Vec<u8> = Vec::new();
+ let mut reader = empty.as_slice();
+ match Bloom::read(&mut reader) {
+ Ok(should_be_none) => assert!(should_be_none.is_none()),
+ Err(_) => panic!("Parsing failed"),
+ };
+ }
+
+ #[test]
+ fn bloom_v3_unsupported() {
+ let src: Vec<u8> = vec![0x03, 0x01, 0x00];
+ assert!(Bloom::read(&mut src.as_slice()).is_err());
+ }
+
+ #[test]
+ fn cascade_v1_murmur_from_file_bytes_test() {
+ let v = include_bytes!("../test_data/test_v1_murmur_mlbf").to_vec();
+ let cascade = Cascade::from_bytes(v)
+ .expect("parsing Cascade should succeed")
+ .expect("Cascade should be Some");
+ // Key format is SHA256(issuer SPKI) + serial number
+ let key_for_revoked_cert_1 = vec![
+ 0x2e, 0xb2, 0xd5, 0xa8, 0x60, 0xfe, 0x50, 0xe9, 0xc2, 0x42, 0x36, 0x85, 0x52, 0x98,
+ 0x01, 0x50, 0xe4, 0x5d, 0xb5, 0x32, 0x1a, 0x5b, 0x00, 0x5e, 0x26, 0xd6, 0x76, 0x25,
+ 0x3a, 0x40, 0x9b, 0xf5, 0x06, 0x2d, 0xf5, 0x68, 0xa0, 0x51, 0x31, 0x08, 0x20, 0xd7,
+ 0xec, 0x43, 0x27, 0xe1, 0xba, 0xfd,
+ ];
+ assert!(cascade.has(key_for_revoked_cert_1));
+ let key_for_revoked_cert_2 = vec![
+ 0xf1, 0x1c, 0x3d, 0xd0, 0x48, 0xf7, 0x4e, 0xdb, 0x7c, 0x45, 0x19, 0x2b, 0x83, 0xe5,
+ 0x98, 0x0d, 0x2f, 0x67, 0xec, 0x84, 0xb4, 0xdd, 0xb9, 0x39, 0x6e, 0x33, 0xff, 0x51,
+ 0x73, 0xed, 0x69, 0x8f, 0x00, 0xd2, 0xe8, 0xf6, 0xaa, 0x80, 0x48, 0x1c, 0xd4,
+ ];
+ assert!(cascade.has(key_for_revoked_cert_2));
+ let key_for_valid_cert = vec![
+ 0x99, 0xfc, 0x9d, 0x40, 0xf1, 0xad, 0xb1, 0x63, 0x65, 0x61, 0xa6, 0x1d, 0x68, 0x3d,
+ 0x9e, 0xa6, 0xb4, 0x60, 0xc5, 0x7d, 0x0c, 0x75, 0xea, 0x00, 0xc3, 0x41, 0xb9, 0xdf,
+ 0xb9, 0x0b, 0x5f, 0x39, 0x0b, 0x77, 0x75, 0xf7, 0xaf, 0x9a, 0xe5, 0x42, 0x65, 0xc9,
+ 0xcd, 0x32, 0x57, 0x10, 0x77, 0x8e,
+ ];
+ assert!(!cascade.has(key_for_valid_cert));
+
+ assert_eq!(cascade.approximate_size_of(), 15408);
+
+ let v = include_bytes!("../test_data/test_v1_murmur_short_mlbf").to_vec();
+ assert!(Cascade::from_bytes(v).is_err());
+ }
+
+ #[test]
+ fn cascade_v2_sha256l32_from_file_bytes_test() {
+ let v = include_bytes!("../test_data/test_v2_sha256l32_mlbf").to_vec();
+ let cascade = Cascade::from_bytes(v)
+ .expect("parsing Cascade should succeed")
+ .expect("Cascade should be Some");
+
+ assert!(cascade.salt.len() == 0);
+ assert!(cascade.inverted == false);
+ assert!(cascade.has(b"this".to_vec()) == true);
+ assert!(cascade.has(b"that".to_vec()) == true);
+ assert!(cascade.has(b"other".to_vec()) == false);
+ assert_eq!(cascade.approximate_size_of(), 1001);
+ }
+
+ #[test]
+ fn cascade_v2_sha256l32_with_salt_from_file_bytes_test() {
+ let v = include_bytes!("../test_data/test_v2_sha256l32_salt_mlbf").to_vec();
+ let cascade = Cascade::from_bytes(v)
+ .expect("parsing Cascade should succeed")
+ .expect("Cascade should be Some");
+
+ assert!(cascade.salt == b"nacl".to_vec());
+ assert!(cascade.inverted == false);
+ assert!(cascade.has(b"this".to_vec()) == true);
+ assert!(cascade.has(b"that".to_vec()) == true);
+ assert!(cascade.has(b"other".to_vec()) == false);
+ assert_eq!(cascade.approximate_size_of(), 1001);
+ }
+
+ #[test]
+ fn cascade_v2_murmur_from_file_bytes_test() {
+ let v = include_bytes!("../test_data/test_v2_murmur_mlbf").to_vec();
+ let cascade = Cascade::from_bytes(v)
+ .expect("parsing Cascade should succeed")
+ .expect("Cascade should be Some");
+
+ assert!(cascade.salt.len() == 0);
+ assert!(cascade.inverted == false);
+ assert!(cascade.has(b"this".to_vec()) == true);
+ assert!(cascade.has(b"that".to_vec()) == true);
+ assert!(cascade.has(b"other".to_vec()) == false);
+ assert_eq!(cascade.approximate_size_of(), 992);
+ }
+
+ #[test]
+ fn cascade_v2_murmur_inverted_from_file_bytes_test() {
+ let v = include_bytes!("../test_data/test_v2_murmur_inverted_mlbf").to_vec();
+ let cascade = Cascade::from_bytes(v)
+ .expect("parsing Cascade should succeed")
+ .expect("Cascade should be Some");
+
+ assert!(cascade.salt.len() == 0);
+ assert!(cascade.inverted == true);
+ assert!(cascade.has(b"this".to_vec()) == true);
+ assert!(cascade.has(b"that".to_vec()) == true);
+ assert!(cascade.has(b"other".to_vec()) == false);
+ assert_eq!(cascade.approximate_size_of(), 1058);
+ }
+
+ #[test]
+ fn cascade_v2_sha256l32_inverted_from_file_bytes_test() {
+ let v = include_bytes!("../test_data/test_v2_sha256l32_inverted_mlbf").to_vec();
+ let cascade = Cascade::from_bytes(v)
+ .expect("parsing Cascade should succeed")
+ .expect("Cascade should be Some");
+
+ assert!(cascade.salt.len() == 0);
+ assert!(cascade.inverted == true);
+ assert!(cascade.has(b"this".to_vec()) == true);
+ assert!(cascade.has(b"that".to_vec()) == true);
+ assert!(cascade.has(b"other".to_vec()) == false);
+ assert_eq!(cascade.approximate_size_of(), 1061);
+ }
+
+ #[test]
+ fn cascade_v2_sha256ctr_from_file_bytes_test() {
+ let v = include_bytes!("../test_data/test_v2_sha256ctr_salt_mlbf").to_vec();
+ let cascade = Cascade::from_bytes(v)
+ .expect("parsing Cascade should succeed")
+ .expect("Cascade should be Some");
+
+ assert!(cascade.salt == b"nacl".to_vec());
+ assert!(cascade.inverted == false);
+ assert!(cascade.has(b"this".to_vec()) == true);
+ assert!(cascade.has(b"that".to_vec()) == true);
+ assert!(cascade.has(b"other".to_vec()) == false);
+ assert_eq!(cascade.approximate_size_of(), 1070);
+ }
+
+ #[test]
+ fn cascade_empty() {
+ let cascade = Cascade::from_bytes(Vec::new()).expect("parsing Cascade should succeed");
+ assert!(cascade.is_none());
+ }
+
+ #[test]
+ fn cascade_test_from_bytes() {
+ let unknown_version: Vec<u8> = vec![0xff, 0xff, 0x00, 0x00];
+ match Cascade::from_bytes(unknown_version) {
+ Ok(_) => panic!("Cascade::from_bytes allows unknown version."),
+ Err(_) => (),
+ }
+
+ let first_layer_is_zero: Vec<u8> = vec![
+ 0x01, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
+ ];
+ match Cascade::from_bytes(first_layer_is_zero) {
+ Ok(_) => panic!("Cascade::from_bytes allows zero indexed layers."),
+ Err(_) => (),
+ }
+
+ let second_layer_is_three: Vec<u8> = vec![
+ 0x01, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01,
+ 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00,
+ ];
+ match Cascade::from_bytes(second_layer_is_three) {
+ Ok(_) => panic!("Cascade::from_bytes allows non-sequential layers."),
+ Err(_) => (),
+ }
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_collision() {
+ let mut builder = CascadeBuilder::default(1, 1);
+ builder.include(b"collision!".to_vec()).ok();
+ builder.exclude(b"collision!".to_vec()).ok();
+ assert!(matches!(builder.finalize(), Err(CascadeError::Collision)));
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_exclude_too_few() {
+ let mut builder = CascadeBuilder::default(1, 1);
+ builder.include(b"1".to_vec()).ok();
+ assert!(matches!(
+ builder.finalize(),
+ Err(CascadeError::CapacityViolation(_))
+ ));
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_include_too_few() {
+ let mut builder = CascadeBuilder::default(1, 1);
+ assert!(matches!(
+ builder.exclude(b"1".to_vec()),
+ Err(CascadeError::CapacityViolation(_))
+ ));
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_include_too_many() {
+ let mut builder = CascadeBuilder::default(1, 1);
+ builder.include(b"1".to_vec()).ok();
+ assert!(matches!(
+ builder.include(b"2".to_vec()),
+ Err(CascadeError::CapacityViolation(_))
+ ));
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_exclude_too_many() {
+ let mut builder = CascadeBuilder::default(1, 1);
+ builder.include(b"1".to_vec()).ok();
+ builder.exclude(b"2".to_vec()).ok();
+ assert!(matches!(
+ builder.exclude(b"3".to_vec()),
+ Err(CascadeError::CapacityViolation(_))
+ ));
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_exclude_threaded_no_collect() {
+ let mut builder = CascadeBuilder::default(1, 3);
+ let mut exclude_set = ExcludeSet::default();
+ builder.include(b"1".to_vec()).ok();
+ builder.exclude_threaded(&mut exclude_set, b"2".to_vec());
+ builder.exclude_threaded(&mut exclude_set, b"3".to_vec());
+ builder.exclude_threaded(&mut exclude_set, b"4".to_vec());
+ assert!(matches!(
+ builder.finalize(),
+ Err(CascadeError::CapacityViolation(_))
+ ));
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_exclude_threaded_too_many() {
+ let mut builder = CascadeBuilder::default(1, 3);
+ let mut exclude_set = ExcludeSet::default();
+ builder.include(b"1".to_vec()).ok();
+ builder.exclude_threaded(&mut exclude_set, b"2".to_vec());
+ builder.exclude_threaded(&mut exclude_set, b"3".to_vec());
+ builder.exclude_threaded(&mut exclude_set, b"4".to_vec());
+ builder.exclude_threaded(&mut exclude_set, b"5".to_vec());
+ assert!(matches!(
+ builder.collect_exclude_set(&mut exclude_set),
+ Err(CascadeError::CapacityViolation(_))
+ ));
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_exclude_threaded() {
+ let mut builder = CascadeBuilder::default(1, 3);
+ let mut exclude_set = ExcludeSet::default();
+ builder.include(b"1".to_vec()).ok();
+ builder.exclude_threaded(&mut exclude_set, b"2".to_vec());
+ builder.exclude_threaded(&mut exclude_set, b"3".to_vec());
+ builder.exclude_threaded(&mut exclude_set, b"4".to_vec());
+ builder.collect_exclude_set(&mut exclude_set).ok();
+ builder.finalize().ok();
+ }
+
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_generate(hash_alg: HashAlgorithm, inverted: bool) {
+ let total = 10_000_usize;
+ let included = 100_usize;
+
+ let salt = vec![0u8; 16];
+ let mut builder =
+ CascadeBuilder::new(hash_alg, salt, included, (total - included) as usize);
+ for i in 0..included {
+ builder.include(i.to_le_bytes().to_vec()).ok();
+ }
+ for i in included..total {
+ builder.exclude(i.to_le_bytes().to_vec()).ok();
+ }
+ let mut cascade = builder.finalize().unwrap();
+
+ if inverted {
+ cascade.invert()
+ }
+
+ // Ensure we can serialize / deserialize
+ let cascade_bytes = cascade.to_bytes().expect("failed to serialize cascade");
+
+ let cascade = Cascade::from_bytes(cascade_bytes)
+ .expect("failed to deserialize cascade")
+ .expect("cascade should not be None here");
+
+ // Ensure each query gives the correct result
+ for i in 0..included {
+ assert!(cascade.has(i.to_le_bytes().to_vec()) == true ^ inverted)
+ }
+ for i in included..total {
+ assert!(cascade.has(i.to_le_bytes().to_vec()) == false ^ inverted)
+ }
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_generate_murmurhash3_inverted() {
+ cascade_builder_test_generate(HashAlgorithm::MurmurHash3, true);
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_generate_murmurhash3() {
+ cascade_builder_test_generate(HashAlgorithm::MurmurHash3, false);
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_generate_sha256l32() {
+ cascade_builder_test_generate(HashAlgorithm::Sha256l32, false);
+ }
+
+ #[test]
+ #[cfg(feature = "builder")]
+ fn cascade_builder_test_generate_sha256() {
+ cascade_builder_test_generate(HashAlgorithm::Sha256, false);
+ }
+}