diff options
Diffstat (limited to 'vendor/ruzstd/src/fse/fse_decoder.rs')
-rw-r--r-- | vendor/ruzstd/src/fse/fse_decoder.rs | 336 |
1 files changed, 336 insertions, 0 deletions
diff --git a/vendor/ruzstd/src/fse/fse_decoder.rs b/vendor/ruzstd/src/fse/fse_decoder.rs new file mode 100644 index 000000000..1847da728 --- /dev/null +++ b/vendor/ruzstd/src/fse/fse_decoder.rs @@ -0,0 +1,336 @@ +use crate::decoding::bit_reader::BitReader; +use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; + +#[derive(Clone)] +pub struct FSETable { + pub decode: Vec<Entry>, //used to decode symbols, and calculate the next state + + pub accuracy_log: u8, + pub symbol_probabilities: Vec<i32>, //used while building the decode Vector + symbol_counter: Vec<u32>, +} + +impl Default for FSETable { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum FSETableError { + #[error("Acclog must be at least 1")] + AccLogIsZero, + #[error("Found FSE acc_log: {got} bigger than allowed maximum in this case: {max}")] + AccLogTooBig { got: u8, max: u8 }, + #[error(transparent)] + GetBitsError(#[from] GetBitsError), + #[error("The counter ({got}) exceeded the expected sum: {expected_sum}. This means an error or corrupted data \n {symbol_probabilities:?}")] + ProbabilityCounterMismatch { + got: u32, + expected_sum: u32, + symbol_probabilities: Vec<i32>, + }, + #[error("There are too many symbols in this distribution: {got}. Max: 256")] + TooManySymbols { got: usize }, +} + +pub struct FSEDecoder<'table> { + pub state: Entry, + table: &'table FSETable, +} + +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum FSEDecoderError { + #[error(transparent)] + GetBitsError(#[from] GetBitsError), + #[error("Tried to use an uninitialized table!")] + TableIsUninitialized, +} + +#[derive(Copy, Clone)] +pub struct Entry { + pub base_line: u32, + pub num_bits: u8, + pub symbol: u8, +} + +const ACC_LOG_OFFSET: u8 = 5; + +fn highest_bit_set(x: u32) -> u32 { + assert!(x > 0); + u32::BITS - x.leading_zeros() +} + +impl<'t> FSEDecoder<'t> { + pub fn new(table: &'t FSETable) -> FSEDecoder<'_> { + FSEDecoder { + state: table.decode.get(0).copied().unwrap_or(Entry { + base_line: 0, + num_bits: 0, + symbol: 0, + }), + table, + } + } + + pub fn decode_symbol(&self) -> u8 { + self.state.symbol + } + + pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> { + if self.table.accuracy_log == 0 { + return Err(FSEDecoderError::TableIsUninitialized); + } + self.state = self.table.decode[bits.get_bits(self.table.accuracy_log)? as usize]; + + Ok(()) + } + + pub fn update_state( + &mut self, + bits: &mut BitReaderReversed<'_>, + ) -> Result<(), FSEDecoderError> { + let num_bits = self.state.num_bits; + let add = bits.get_bits(num_bits)?; + let base_line = self.state.base_line; + let new_state = base_line + add as u32; + self.state = self.table.decode[new_state as usize]; + + //println!("Update: {}, {} -> {}", base_line, add, self.state); + Ok(()) + } +} + +impl FSETable { + pub fn new() -> FSETable { + FSETable { + symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8 + symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8 + decode: Vec::new(), //depending on acc_log. + accuracy_log: 0, + } + } + + pub fn reset(&mut self) { + self.symbol_counter.clear(); + self.symbol_probabilities.clear(); + self.decode.clear(); + self.accuracy_log = 0; + } + + //returns how many BYTEs (not bits) were read while building the decoder + pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> { + self.accuracy_log = 0; + + let bytes_read = self.read_probabilities(source, max_log)?; + self.build_decoding_table(); + + Ok(bytes_read) + } + + pub fn build_from_probabilities( + &mut self, + acc_log: u8, + probs: &[i32], + ) -> Result<(), FSETableError> { + if acc_log == 0 { + return Err(FSETableError::AccLogIsZero); + } + self.symbol_probabilities = probs.to_vec(); + self.accuracy_log = acc_log; + self.build_decoding_table(); + Ok(()) + } + + fn build_decoding_table(&mut self) { + self.decode.clear(); + + let table_size = 1 << self.accuracy_log; + if self.decode.len() < table_size { + self.decode.reserve(table_size - self.decode.len()); + } + //fill with dummy entries + self.decode.resize( + table_size, + Entry { + base_line: 0, + num_bits: 0, + symbol: 0, + }, + ); + + let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol + + //first scan for all -1 probabilities and place them at the top of the table + for symbol in 0..self.symbol_probabilities.len() { + if self.symbol_probabilities[symbol] == -1 { + negative_idx -= 1; + let entry = &mut self.decode[negative_idx]; + entry.symbol = symbol as u8; + entry.base_line = 0; + entry.num_bits = self.accuracy_log; + } + } + + //then place in a semi-random order all of the other symbols + let mut position = 0; + for idx in 0..self.symbol_probabilities.len() { + let symbol = idx as u8; + if self.symbol_probabilities[idx] <= 0 { + continue; + } + + //for each probability point the symbol gets on slot + let prob = self.symbol_probabilities[idx]; + for _ in 0..prob { + let entry = &mut self.decode[position]; + entry.symbol = symbol; + + position = next_position(position, table_size); + while position >= negative_idx { + position = next_position(position, table_size); + //everything above negative_idx is already taken + } + } + } + + // baselines and num_bits can only be calculated when all symbols have been spread + self.symbol_counter.clear(); + self.symbol_counter + .resize(self.symbol_probabilities.len(), 0); + for idx in 0..negative_idx { + let entry = &mut self.decode[idx]; + let symbol = entry.symbol; + let prob = self.symbol_probabilities[symbol as usize]; + + let symbol_count = self.symbol_counter[symbol as usize]; + let (bl, nb) = calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count); + + //println!("symbol: {:2}, table: {}, prob: {:3}, count: {:3}, bl: {:3}, nb: {:2}", symbol, table_size, prob, symbol_count, bl, nb); + + assert!(nb <= self.accuracy_log); + self.symbol_counter[symbol as usize] += 1; + + entry.base_line = bl; + entry.num_bits = nb; + } + } + + fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> { + self.symbol_probabilities.clear(); //just clear, we will fill a probability for each entry anyways. No need to force new allocs here + + let mut br = BitReader::new(source); + self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8); + if self.accuracy_log > max_log { + return Err(FSETableError::AccLogTooBig { + got: self.accuracy_log, + max: max_log, + }); + } + if self.accuracy_log == 0 { + return Err(FSETableError::AccLogIsZero); + } + + let probablility_sum = 1 << self.accuracy_log; + let mut probability_counter = 0; + + while probability_counter < probablility_sum { + let max_remaining_value = probablility_sum - probability_counter + 1; + let bits_to_read = highest_bit_set(max_remaining_value); + + let unchecked_value = br.get_bits(bits_to_read as usize)? as u32; + + let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value); + let mask = (1 << (bits_to_read - 1)) - 1; + let small_value = unchecked_value & mask; + + let value = if small_value < low_threshold { + br.return_bits(1); + small_value + } else if unchecked_value > mask { + unchecked_value - low_threshold + } else { + unchecked_value + }; + //println!("{}, {}, {}", self.symbol_probablilities.len(), unchecked_value, value); + + let prob = (value as i32) - 1; + + self.symbol_probabilities.push(prob); + if prob != 0 { + if prob > 0 { + probability_counter += prob as u32; + } else { + // probability -1 counts as 1 + assert!(prob == -1); + probability_counter += 1; + } + } else { + //fast skip further zero probabilities + loop { + let skip_amount = br.get_bits(2)? as usize; + + self.symbol_probabilities + .resize(self.symbol_probabilities.len() + skip_amount, 0); + if skip_amount != 3 { + break; + } + } + } + } + + if probability_counter != probablility_sum { + return Err(FSETableError::ProbabilityCounterMismatch { + got: probability_counter, + expected_sum: probablility_sum, + symbol_probabilities: self.symbol_probabilities.clone(), + }); + } + if self.symbol_probabilities.len() > 256 { + return Err(FSETableError::TooManySymbols { + got: self.symbol_probabilities.len(), + }); + } + + let bytes_read = if br.bits_read() % 8 == 0 { + br.bits_read() / 8 + } else { + (br.bits_read() / 8) + 1 + }; + Ok(bytes_read) + } +} + +//utility functions for building the decoding table from probabilities +fn next_position(mut p: usize, table_size: usize) -> usize { + p += (table_size >> 1) + (table_size >> 3) + 3; + p &= table_size - 1; + p +} + +fn calc_baseline_and_numbits( + num_states_total: u32, + num_states_symbol: u32, + state_number: u32, +) -> (u32, u8) { + let num_state_slices = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol { + num_states_symbol + } else { + 1 << (highest_bit_set(num_states_symbol)) + }; //always power of two + + let num_double_width_state_slices = num_state_slices - num_states_symbol; //leftovers to the power of two need to be distributed + let num_single_width_state_slices = num_states_symbol - num_double_width_state_slices; //these will not receive a double width slice of states + let slice_width = num_states_total / num_state_slices; //size of a single width slice of states + let num_bits = highest_bit_set(slice_width) - 1; //number of bits needed to read for one slice + + if state_number < num_double_width_state_slices { + let baseline = num_single_width_state_slices * slice_width + state_number * slice_width * 2; + (baseline, num_bits as u8 + 1) + } else { + let index_shifted = state_number - num_double_width_state_slices; + ((index_shifted * slice_width), num_bits as u8) + } +} |