diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/jpeg-xl/lib/jxl/dec_ans.h | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/dec_ans.h')
-rw-r--r-- | third_party/jpeg-xl/lib/jxl/dec_ans.h | 462 |
1 files changed, 462 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/dec_ans.h b/third_party/jpeg-xl/lib/jxl/dec_ans.h new file mode 100644 index 0000000000..0f4406745a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_ans.h @@ -0,0 +1,462 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_ANS_H_ +#define LIB_JXL_DEC_ANS_H_ + +// Library to decode the ANS population counts from the bit-stream and build a +// decoding table from them. + +#include <stddef.h> +#include <stdint.h> + +#include <cstring> +#include <vector> + +#include "lib/jxl/ans_common.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/cache_aligned.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_huffman.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +class ANSSymbolReader; + +// Experiments show that best performance is typically achieved for a +// split-exponent of 3 or 4. Trend seems to be that '4' is better +// for large-ish pictures, and '3' better for rather small-ish pictures. +// This is plausible - the more special symbols we have, the better +// statistics we need to get a benefit out of them. + +// Our hybrid-encoding scheme has dedicated tokens for the smallest +// (1 << split_exponents) numbers, and for the rest +// encodes (number of bits) + (msb_in_token sub-leading binary digits) + +// (lsb_in_token lowest binary digits) in the token, with the remaining bits +// then being encoded as data. +// +// Example with split_exponent = 4, msb_in_token = 2, lsb_in_token = 0. +// +// Numbers N in [0 .. 15]: +// These get represented as (token=N, bits=''). +// Numbers N >= 16: +// If n is such that 2**n <= N < 2**(n+1), +// and m = N - 2**n is the 'mantissa', +// these get represented as: +// (token=split_token + +// ((n - split_exponent) * 4) + +// (m >> (n - msb_in_token)), +// bits=m & (1 << (n - msb_in_token)) - 1) +// Specifically, we would get: +// N = 0 - 15: (token=N, nbits=0, bits='') +// N = 16 (10000): (token=16, nbits=2, bits='00') +// N = 17 (10001): (token=16, nbits=2, bits='01') +// N = 20 (10100): (token=17, nbits=2, bits='00') +// N = 24 (11000): (token=18, nbits=2, bits='00') +// N = 28 (11100): (token=19, nbits=2, bits='00') +// N = 32 (100000): (token=20, nbits=3, bits='000') +// N = 65535: (token=63, nbits=13, bits='1111111111111') +struct HybridUintConfig { + uint32_t split_exponent; + uint32_t split_token; + uint32_t msb_in_token; + uint32_t lsb_in_token; + JXL_INLINE void Encode(uint32_t value, uint32_t* JXL_RESTRICT token, + uint32_t* JXL_RESTRICT nbits, + uint32_t* JXL_RESTRICT bits) const { + if (value < split_token) { + *token = value; + *nbits = 0; + *bits = 0; + } else { + uint32_t n = FloorLog2Nonzero(value); + uint32_t m = value - (1 << n); + *token = split_token + + ((n - split_exponent) << (msb_in_token + lsb_in_token)) + + ((m >> (n - msb_in_token)) << lsb_in_token) + + (m & ((1 << lsb_in_token) - 1)); + *nbits = n - msb_in_token - lsb_in_token; + *bits = (value >> lsb_in_token) & ((1UL << *nbits) - 1); + } + } + + explicit HybridUintConfig(uint32_t split_exponent = 4, + uint32_t msb_in_token = 2, + uint32_t lsb_in_token = 0) + : split_exponent(split_exponent), + split_token(1 << split_exponent), + msb_in_token(msb_in_token), + lsb_in_token(lsb_in_token) { + JXL_DASSERT(split_exponent >= msb_in_token + lsb_in_token); + } +}; + +struct LZ77Params : public Fields { + LZ77Params(); + JXL_FIELDS_NAME(LZ77Params) + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + bool enabled; + + // Symbols above min_symbol use a special hybrid uint encoding and + // represent a length, to be added to min_length. + uint32_t min_symbol; + uint32_t min_length; + + // Not serialized by VisitFields. + HybridUintConfig length_uint_config{0, 0, 0}; + + size_t nonserialized_distance_context; +}; + +static constexpr size_t kWindowSize = 1 << 20; +static constexpr size_t kNumSpecialDistances = 120; +// Table of special distance codes from WebP lossless. +static constexpr int8_t kSpecialDistances[kNumSpecialDistances][2] = { + {0, 1}, {1, 0}, {1, 1}, {-1, 1}, {0, 2}, {2, 0}, {1, 2}, {-1, 2}, + {2, 1}, {-2, 1}, {2, 2}, {-2, 2}, {0, 3}, {3, 0}, {1, 3}, {-1, 3}, + {3, 1}, {-3, 1}, {2, 3}, {-2, 3}, {3, 2}, {-3, 2}, {0, 4}, {4, 0}, + {1, 4}, {-1, 4}, {4, 1}, {-4, 1}, {3, 3}, {-3, 3}, {2, 4}, {-2, 4}, + {4, 2}, {-4, 2}, {0, 5}, {3, 4}, {-3, 4}, {4, 3}, {-4, 3}, {5, 0}, + {1, 5}, {-1, 5}, {5, 1}, {-5, 1}, {2, 5}, {-2, 5}, {5, 2}, {-5, 2}, + {4, 4}, {-4, 4}, {3, 5}, {-3, 5}, {5, 3}, {-5, 3}, {0, 6}, {6, 0}, + {1, 6}, {-1, 6}, {6, 1}, {-6, 1}, {2, 6}, {-2, 6}, {6, 2}, {-6, 2}, + {4, 5}, {-4, 5}, {5, 4}, {-5, 4}, {3, 6}, {-3, 6}, {6, 3}, {-6, 3}, + {0, 7}, {7, 0}, {1, 7}, {-1, 7}, {5, 5}, {-5, 5}, {7, 1}, {-7, 1}, + {4, 6}, {-4, 6}, {6, 4}, {-6, 4}, {2, 7}, {-2, 7}, {7, 2}, {-7, 2}, + {3, 7}, {-3, 7}, {7, 3}, {-7, 3}, {5, 6}, {-5, 6}, {6, 5}, {-6, 5}, + {8, 0}, {4, 7}, {-4, 7}, {7, 4}, {-7, 4}, {8, 1}, {8, 2}, {6, 6}, + {-6, 6}, {8, 3}, {5, 7}, {-5, 7}, {7, 5}, {-7, 5}, {8, 4}, {6, 7}, + {-6, 7}, {7, 6}, {-7, 6}, {8, 5}, {7, 7}, {-7, 7}, {8, 6}, {8, 7}}; + +struct ANSCode { + CacheAlignedUniquePtr alias_tables; + std::vector<HuffmanDecodingData> huffman_data; + std::vector<HybridUintConfig> uint_config; + std::vector<int> degenerate_symbols; + bool use_prefix_code; + uint8_t log_alpha_size; // for ANS. + LZ77Params lz77; + // Maximum number of bits necessary to represent the result of a + // ReadHybridUint call done with this ANSCode. + size_t max_num_bits = 0; + void UpdateMaxNumBits(size_t ctx, size_t symbol); +}; + +class ANSSymbolReader { + public: + // Invalid symbol reader, to be overwritten. + ANSSymbolReader() = default; + ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br, + size_t distance_multiplier = 0) + : alias_tables_( + reinterpret_cast<AliasTable::Entry*>(code->alias_tables.get())), + huffman_data_(code->huffman_data.data()), + use_prefix_code_(code->use_prefix_code), + configs(code->uint_config.data()) { + if (!use_prefix_code_) { + state_ = static_cast<uint32_t>(br->ReadFixedBits<32>()); + log_alpha_size_ = code->log_alpha_size; + log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size; + entry_size_minus_1_ = (1 << log_entry_size_) - 1; + } else { + state_ = (ANS_SIGNATURE << 16u); + } + if (!code->lz77.enabled) return; + // a std::vector incurs unacceptable decoding speed loss because of + // initialization. + lz77_window_storage_ = AllocateArray(kWindowSize * sizeof(uint32_t)); + lz77_window_ = reinterpret_cast<uint32_t*>(lz77_window_storage_.get()); + lz77_ctx_ = code->lz77.nonserialized_distance_context; + lz77_length_uint_ = code->lz77.length_uint_config; + lz77_threshold_ = code->lz77.min_symbol; + lz77_min_length_ = code->lz77.min_length; + num_special_distances_ = + distance_multiplier == 0 ? 0 : kNumSpecialDistances; + for (size_t i = 0; i < num_special_distances_; i++) { + int dist = kSpecialDistances[i][0]; + dist += static_cast<int>(distance_multiplier) * kSpecialDistances[i][1]; + if (dist < 1) dist = 1; + special_distances_[i] = dist; + } + } + + JXL_INLINE size_t ReadSymbolANSWithoutRefill(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); + + const AliasTable::Entry* table = + &alias_tables_[histo_idx << log_alpha_size_]; + const AliasTable::Symbol symbol = + AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); + state_ = symbol.freq * (state_ >> ANS_LOG_TAB_SIZE) + symbol.offset; + +#if 1 + // Branchless version is about equally fast on SKX. + const uint32_t new_state = + (state_ << 16u) | static_cast<uint32_t>(br->PeekFixedBits<16>()); + const bool normalize = state_ < (1u << 16u); + state_ = normalize ? new_state : state_; + br->Consume(normalize ? 16 : 0); +#else + if (JXL_UNLIKELY(state_ < (1u << 16u))) { + state_ = (state_ << 16u) | br->PeekFixedBits<16>(); + br->Consume(16); + } +#endif + const uint32_t next_res = state_ & (ANS_TAB_SIZE - 1u); + AliasTable::Prefetch(table, next_res, log_entry_size_); + + return symbol.value; + } + + JXL_INLINE size_t ReadSymbolHuffWithoutRefill(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + return huffman_data_[histo_idx].ReadSymbol(br); + } + + JXL_INLINE size_t ReadSymbolWithoutRefill(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + // TODO(veluca): hoist if in hotter loops. + if (JXL_UNLIKELY(use_prefix_code_)) { + return ReadSymbolHuffWithoutRefill(histo_idx, br); + } + return ReadSymbolANSWithoutRefill(histo_idx, br); + } + + JXL_INLINE size_t ReadSymbol(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + br->Refill(); + return ReadSymbolWithoutRefill(histo_idx, br); + } + +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + bool CheckANSFinalState() const { return true; } +#else + bool CheckANSFinalState() const { return state_ == (ANS_SIGNATURE << 16u); } +#endif + + template <typename BitReader> + static JXL_INLINE uint32_t ReadHybridUintConfig( + const HybridUintConfig& config, size_t token, BitReader* br) { + size_t split_token = config.split_token; + size_t msb_in_token = config.msb_in_token; + size_t lsb_in_token = config.lsb_in_token; + size_t split_exponent = config.split_exponent; + // Fast-track version of hybrid integer decoding. + if (token < split_token) return token; + uint32_t nbits = split_exponent - (msb_in_token + lsb_in_token) + + ((token - split_token) >> (msb_in_token + lsb_in_token)); + // Max amount of bits for ReadBits is 32 and max valid left shift is 29 + // bits. However, for speed no error is propagated here, instead limit the + // nbits size. If nbits > 29, the code stream is invalid, but no error is + // returned. + // Note that in most cases we will emit an error if the histogram allows + // representing numbers that would cause invalid shifts, but we need to + // keep this check as when LZ77 is enabled it might make sense to have an + // histogram that could in principle cause invalid shifts. + nbits &= 31u; + uint32_t low = token & ((1 << lsb_in_token) - 1); + token >>= lsb_in_token; + const size_t bits = br->PeekBits(nbits); + br->Consume(nbits); + size_t ret = (((((1 << msb_in_token) | (token & ((1 << msb_in_token) - 1))) + << nbits) | + bits) + << lsb_in_token) | + low; + // TODO(eustas): mark BitReader as unhealthy if nbits > 29 or ret does not + // fit uint32_t + return static_cast<uint32_t>(ret); + } + + // Takes a *clustered* idx. Can only use if HuffRleOnly() is true. + void ReadHybridUintClusteredHuffRleOnly(size_t ctx, + BitReader* JXL_RESTRICT br, + uint32_t* value, uint32_t* run) { + JXL_DASSERT(HuffRleOnly()); + br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits + size_t token = ReadSymbolHuffWithoutRefill(ctx, br); + if (JXL_UNLIKELY(token >= lz77_threshold_)) { + *run = + ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_, br) + + lz77_min_length_ - 1; + return; + } + *value = ReadHybridUintConfig(configs[ctx], token, br); + } + bool HuffRleOnly() { + if (lz77_window_ == nullptr) return false; + if (!use_prefix_code_) return false; + for (size_t i = 0; i < kHuffmanTableBits; i++) { + if (huffman_data_[lz77_ctx_].table_[i].bits) return false; + if (huffman_data_[lz77_ctx_].table_[i].value != 1) return false; + } + if (configs[lz77_ctx_].split_token > 1) return false; + return true; + } + + // Takes a *clustered* idx. + size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) { + if (JXL_UNLIKELY(num_to_copy_ > 0)) { + size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; + num_to_copy_--; + lz77_window_[(num_decoded_++) & kWindowMask] = ret; + return ret; + } + br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits + size_t token = ReadSymbolWithoutRefill(ctx, br); + if (JXL_UNLIKELY(token >= lz77_threshold_)) { + num_to_copy_ = + ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_, br) + + lz77_min_length_; + br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits + // Distance code. + size_t token = ReadSymbolWithoutRefill(lz77_ctx_, br); + size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], token, br); + if (JXL_LIKELY(distance < num_special_distances_)) { + distance = special_distances_[distance]; + } else { + distance = distance + 1 - num_special_distances_; + } + if (JXL_UNLIKELY(distance > num_decoded_)) { + distance = num_decoded_; + } + if (JXL_UNLIKELY(distance > kWindowSize)) { + distance = kWindowSize; + } + copy_pos_ = num_decoded_ - distance; + if (JXL_UNLIKELY(distance == 0)) { + JXL_DASSERT(lz77_window_ != nullptr); + // distance 0 -> num_decoded_ == copy_pos_ == 0 + size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize); + memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0])); + } + // TODO(eustas): overflow; mark BitReader as unhealthy + if (num_to_copy_ < lz77_min_length_) return 0; + return ReadHybridUintClustered(ctx, br); // will trigger a copy. + } + size_t ret = ReadHybridUintConfig(configs[ctx], token, br); + if (lz77_window_) lz77_window_[(num_decoded_++) & kWindowMask] = ret; + return ret; + } + + JXL_INLINE size_t ReadHybridUint(size_t ctx, BitReader* JXL_RESTRICT br, + const std::vector<uint8_t>& context_map) { + return ReadHybridUintClustered(context_map[ctx], br); + } + + // ctx is a *clustered* context! + // This function will modify the ANS state as if `count` symbols have been + // decoded. + bool IsSingleValueAndAdvance(size_t ctx, uint32_t* value, size_t count) { + // TODO(veluca): No optimization for Huffman mode yet. + if (use_prefix_code_) return false; + // TODO(eustas): propagate "degenerate_symbol" to simplify this method. + const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); + const AliasTable::Entry* table = &alias_tables_[ctx << log_alpha_size_]; + AliasTable::Symbol symbol = + AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); + if (symbol.freq != ANS_TAB_SIZE) return false; + if (configs[ctx].split_token <= symbol.value) return false; + if (symbol.value >= lz77_threshold_) return false; + *value = symbol.value; + if (lz77_window_) { + for (size_t i = 0; i < count; i++) { + lz77_window_[(num_decoded_++) & kWindowMask] = symbol.value; + } + } + return true; + } + + static constexpr size_t kMaxCheckpointInterval = 512; + struct Checkpoint { + uint32_t state; + uint32_t num_to_copy; + uint32_t copy_pos; + uint32_t num_decoded; + uint32_t lz77_window[kMaxCheckpointInterval]; + }; + void Save(Checkpoint* checkpoint) { + checkpoint->state = state_; + checkpoint->num_decoded = num_decoded_; + checkpoint->num_to_copy = num_to_copy_; + checkpoint->copy_pos = copy_pos_; + if (lz77_window_) { + size_t win_start = num_decoded_ & kWindowMask; + size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; + if (win_end > win_start) { + memcpy(checkpoint->lz77_window, lz77_window_ + win_start, + (win_end - win_start) * sizeof(*lz77_window_)); + } else { + memcpy(checkpoint->lz77_window, lz77_window_ + win_start, + (kWindowSize - win_start) * sizeof(*lz77_window_)); + memcpy(checkpoint->lz77_window + (kWindowSize - win_start), + lz77_window_, win_end * sizeof(*lz77_window_)); + } + } + } + void Restore(const Checkpoint& checkpoint) { + state_ = checkpoint.state; + JXL_DASSERT(num_decoded_ <= + checkpoint.num_decoded + kMaxCheckpointInterval); + num_decoded_ = checkpoint.num_decoded; + num_to_copy_ = checkpoint.num_to_copy; + copy_pos_ = checkpoint.copy_pos; + if (lz77_window_) { + size_t win_start = num_decoded_ & kWindowMask; + size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; + if (win_end > win_start) { + memcpy(lz77_window_ + win_start, checkpoint.lz77_window, + (win_end - win_start) * sizeof(*lz77_window_)); + } else { + memcpy(lz77_window_ + win_start, checkpoint.lz77_window, + (kWindowSize - win_start) * sizeof(*lz77_window_)); + memcpy(lz77_window_, checkpoint.lz77_window + (kWindowSize - win_start), + win_end * sizeof(*lz77_window_)); + } + } + } + + private: + const AliasTable::Entry* JXL_RESTRICT alias_tables_; // not owned + const HuffmanDecodingData* huffman_data_; + bool use_prefix_code_; + uint32_t state_ = ANS_SIGNATURE << 16u; + const HybridUintConfig* JXL_RESTRICT configs; + uint32_t log_alpha_size_{}; + uint32_t log_entry_size_{}; + uint32_t entry_size_minus_1_{}; + + // LZ77 structures and constants. + static constexpr size_t kWindowMask = kWindowSize - 1; + CacheAlignedUniquePtr lz77_window_storage_; + uint32_t* lz77_window_ = nullptr; + uint32_t num_decoded_ = 0; + uint32_t num_to_copy_ = 0; + uint32_t copy_pos_ = 0; + uint32_t lz77_ctx_ = 0; + uint32_t lz77_min_length_ = 0; + uint32_t lz77_threshold_ = 1 << 20; // bigger than any symbol. + HybridUintConfig lz77_length_uint_; + uint32_t special_distances_[kNumSpecialDistances]{}; + uint32_t num_special_distances_{}; +}; + +Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, + std::vector<uint8_t>* context_map, + bool disallow_lz77 = false); + +// Exposed for tests. +Status DecodeUintConfigs(size_t log_alpha_size, + std::vector<HybridUintConfig>* uint_config, + BitReader* br); + +} // namespace jxl + +#endif // LIB_JXL_DEC_ANS_H_ |