diff options
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/enc_ans.cc')
-rw-r--r-- | third_party/jpeg-xl/lib/jxl/enc_ans.cc | 172 |
1 files changed, 112 insertions, 60 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/enc_ans.cc b/third_party/jpeg-xl/lib/jxl/enc_ans.cc index 3efa62d8e1..5e59790b1e 100644 --- a/third_party/jpeg-xl/lib/jxl/enc_ans.cc +++ b/third_party/jpeg-xl/lib/jxl/enc_ans.cc @@ -5,11 +5,13 @@ #include "lib/jxl/enc_ans.h" +#include <jxl/types.h> #include <stdint.h> #include <algorithm> #include <array> #include <cmath> +#include <cstdint> #include <limits> #include <numeric> #include <type_traits> @@ -20,12 +22,15 @@ #include "lib/jxl/ans_common.h" #include "lib/jxl/base/bits.h" #include "lib/jxl/base/fast_math-inl.h" +#include "lib/jxl/base/status.h" #include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_ans_params.h" #include "lib/jxl/enc_aux_out.h" #include "lib/jxl/enc_cluster.h" #include "lib/jxl/enc_context_map.h" #include "lib/jxl/enc_fields.h" #include "lib/jxl/enc_huffman.h" +#include "lib/jxl/enc_params.h" #include "lib/jxl/fields.h" namespace jxl { @@ -37,7 +42,7 @@ constexpr #endif bool ans_fuzzer_friendly_ = false; -static const int kMaxNumSymbolsForSmallCode = 4; +const int kMaxNumSymbolsForSmallCode = 4; void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table, size_t alphabet_size, size_t log_alpha_size, @@ -99,16 +104,16 @@ float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) { // Static Huffman code for encoding logcounts. The last symbol is used as RLE // sequence. -static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = { +const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = { 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7, }; -static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = { +const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = { 17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65, }; // Returns the difference between largest count that can be represented and is // smaller than "count" and smallest representable count larger than "count". -static int SmallestIncrement(uint32_t count, uint32_t shift) { +int SmallestIncrement(uint32_t count, uint32_t shift) { int bits = count == 0 ? -1 : FloorLog2Nonzero(count); int drop_bits = bits - GetPopulationCountPrecision(bits, shift); return drop_bits < 0 ? 1 : (1 << drop_bits); @@ -148,10 +153,11 @@ bool RebalanceHistogram(const float* targets, int max_symbol, int table_size, int inc = SmallestIncrement(counts[n], shift); counts[n] -= counts[n] & (inc - 1); // TODO(robryk): Should we rescale targets[n]? - const float target = - minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n]; + const int target = minimize_error_of_sum + ? (static_cast<int>(sum_nonrounded) - sum) + : static_cast<int>(targets[n]); if (counts[n] == 0 || - (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) { + (target >= counts[n] + inc / 2 && counts[n] + inc < table_size)) { counts[n] += inc; } sum += counts[n]; @@ -203,11 +209,11 @@ Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length, for (size_t n = 0; n < targets.size(); ++n) { targets[n] = norm * counts[n]; } - if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift, + if (!RebalanceHistogram<false>(targets.data(), max_symbol, table_size, shift, omit_pos, counts)) { // Use an alternative rebalancing mechanism if the one above failed // to create a histogram that is positive wherever the original one was. - if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift, + if (!RebalanceHistogram<true>(targets.data(), max_symbol, table_size, shift, omit_pos, counts)) { return JXL_FAILURE("Logic error: couldn't rebalance a histogram"); } @@ -482,8 +488,8 @@ size_t BuildAndStoreANSEncodingData( std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size); if (!counts.empty()) { size_t sum = 0; - for (size_t i = 0; i < counts.size(); i++) { - sum += counts[i]; + for (int count : counts) { + sum += count; } if (sum == 0) { counts[0] = ANS_TAB_SIZE; @@ -538,8 +544,8 @@ template <typename Writer> void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config, Writer* writer, size_t log_alpha_size) { // TODO(veluca): RLE? - for (size_t i = 0; i < uint_config.size(); i++) { - EncodeUintConfig(uint_config[i], writer, log_alpha_size); + for (const auto& cfg : uint_config) { + EncodeUintConfig(cfg, writer, log_alpha_size); } } template void EncodeUintConfigs(const std::vector<HybridUintConfig>&, @@ -553,8 +559,7 @@ void ChooseUintConfigs(const HistogramParams& params, std::vector<Histogram>* clustered_histograms, EntropyEncodingData* codes, size_t* log_alpha_size) { codes->uint_config.resize(clustered_histograms->size()); - if (params.streaming_mode || - params.uint_method == HistogramParams::HybridUintMethod::kNone) { + if (params.uint_method == HistogramParams::HybridUintMethod::kNone) { return; } if (params.uint_method == HistogramParams::HybridUintMethod::k000) { @@ -570,6 +575,12 @@ void ChooseUintConfigs(const HistogramParams& params, return; } + // If the uint config is adaptive, just stick with the default in streaming + // mode. + if (params.streaming_mode) { + return; + } + // Brute-force method that tries a few options. std::vector<HybridUintConfig> configs; if (params.uint_method == HistogramParams::HybridUintMethod::kBest) { @@ -619,12 +630,11 @@ void ChooseUintConfigs(const HistogramParams& params, std::fill(is_valid.begin(), is_valid.end(), true); std::fill(extra_bits.begin(), extra_bits.end(), 0); - for (size_t i = 0; i < clustered_histograms->size(); i++) { - (*clustered_histograms)[i].Clear(); + for (auto& histo : *clustered_histograms) { + histo.Clear(); } - for (size_t i = 0; i < tokens.size(); ++i) { - for (size_t j = 0; j < tokens[i].size(); ++j) { - const Token token = tokens[i][j]; + for (const auto& stream : tokens) { + for (const auto& token : stream) { // TODO(veluca): do not ignore lz77 commands. if (token.is_lz77_length) continue; size_t histo = context_map[token.context]; @@ -632,7 +642,7 @@ void ChooseUintConfigs(const HistogramParams& params, cfg.Encode(token.value, &tok, &nbits, &bits); if (tok >= max_alpha || (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) { - is_valid[histo] = false; + is_valid[histo] = JXL_FALSE; continue; } extra_bits[histo] += nbits; @@ -654,13 +664,12 @@ void ChooseUintConfigs(const HistogramParams& params, } // Rebuild histograms. - for (size_t i = 0; i < clustered_histograms->size(); i++) { - (*clustered_histograms)[i].Clear(); + for (auto& histo : *clustered_histograms) { + histo.Clear(); } *log_alpha_size = 4; - for (size_t i = 0; i < tokens.size(); ++i) { - for (size_t j = 0; j < tokens[i].size(); ++j) { - const Token token = tokens[i][j]; + for (const auto& stream : tokens) { + for (const auto& token : stream) { uint32_t tok, nbits, bits; size_t histo = context_map[token.context]; (token.is_lz77_length ? codes->lz77.length_uint_config @@ -771,7 +780,7 @@ class HistogramBuilder { } SizeWriter size_writer; // Used if writer == nullptr to estimate costs. cost += 1; - if (writer) writer->Write(1, codes->use_prefix_code); + if (writer) writer->Write(1, TO_JXL_BOOL(codes->use_prefix_code)); if (codes->use_prefix_code) { log_alpha_size = PREFIX_MAX_BITS; @@ -785,8 +794,8 @@ class HistogramBuilder { EncodeUintConfigs(codes->uint_config, writer, log_alpha_size); } if (codes->use_prefix_code) { - for (size_t c = 0; c < clustered_histograms.size(); ++c) { - size_t alphabet_size = clustered_histograms[c].alphabet_size(); + for (const auto& histo : clustered_histograms) { + size_t alphabet_size = histo.alphabet_size(); if (writer) { StoreVarLenUint16(alphabet_size - 1, writer); } else { @@ -832,9 +841,8 @@ class SymbolCostEstimator { HistogramBuilder builder(num_contexts); // Build histograms for estimating lz77 savings. HybridUintConfig uint_config; - for (size_t i = 0; i < tokens.size(); ++i) { - for (size_t j = 0; j < tokens[i].size(); ++j) { - const Token token = tokens[i][j]; + for (const auto& stream : tokens) { + for (const auto& token : stream) { uint32_t tok, nbits, bits; (token.is_lz77_length ? lz77.length_uint_config : uint_config) .Encode(token.value, &tok, &nbits, &bits); @@ -1025,12 +1033,7 @@ struct HashChain { // Count down, so if due to small distance multiplier multiple distances // map to the same code, the smallest code will be used in the end. for (int i = kNumSpecialDistances - 1; i >= 0; --i) { - int xi = kSpecialDistances[i][0]; - int yi = kSpecialDistances[i][1]; - int distance = yi * distance_multiplier + xi; - // Ensure that we map distance 1 to the lowest symbols. - if (distance < 1) distance = 1; - special_dist_table_[distance] = i; + special_dist_table_[SpecialDistance(i, distance_multiplier)] = i; } num_special_distances_ = kNumSpecialDistances; } @@ -1041,9 +1044,9 @@ struct HashChain { if (pos + 2 < size_) { // TODO(lode): take the MSB's of the uint32_t values into account as well, // given that the hash code itself is less than 32 bits. - result ^= (uint32_t)(data_[pos + 0] << 0u); - result ^= (uint32_t)(data_[pos + 1] << hash_shift_); - result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2)); + result ^= static_cast<uint32_t>(data_[pos + 0] << 0u); + result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_); + result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2)); } else { // No need to compute hash of last 2 bytes, the length 2 is too short. return 0; @@ -1071,7 +1074,7 @@ struct HashChain { uint32_t hashval = GetHash(pos); uint32_t wpos = pos & window_mask_; - val[wpos] = (int)hashval; + val[wpos] = static_cast<int>(hashval); if (head[hashval] != -1) chain[wpos] = head[hashval]; head[hashval] = wpos; @@ -1142,7 +1145,10 @@ struct HashChain { } else { if (hashpos == chain[hashpos]) break; hashpos = chain[hashpos]; - if (val[hashpos] != (int)hashval) break; // outdated hash value + if (val[hashpos] != static_cast<int>(hashval)) { + // outdated hash value + break; + } } } } @@ -1274,7 +1280,8 @@ void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts, HashChain chain(in.data(), in.size(), window_size, min_length, max_length, distance_multiplier); - size_t len, dist_symbol; + size_t len; + size_t dist_symbol; const size_t max_lazy_match_len = 256; // 0 to disable lazy matching @@ -1507,7 +1514,7 @@ void EncodeHistograms(const std::vector<uint8_t>& context_map, } EncodeContextMap(context_map, codes.encoding_info.size(), writer, layer, aux_out); - writer->Write(1, codes.use_prefix_code); + writer->Write(1, TO_JXL_BOOL(codes.use_prefix_code)); size_t log_alpha_size = 8; if (codes.use_prefix_code) { log_alpha_size = PREFIX_MAX_BITS; @@ -1583,10 +1590,9 @@ size_t BuildAndEncodeHistograms(const HistogramParams& params, if (ans_fuzzer_friendly_) { uint_config = HybridUintConfig(10, 0, 0); } - for (size_t i = 0; i < tokens.size(); ++i) { + for (const auto& stream : tokens) { if (codes->lz77.enabled) { - for (size_t j = 0; j < tokens[i].size(); ++j) { - const Token& token = tokens[i][j]; + for (const auto& token : stream) { total_tokens++; uint32_t tok, nbits, bits; (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config) @@ -1595,16 +1601,14 @@ size_t BuildAndEncodeHistograms(const HistogramParams& params, builder.VisitSymbol(tok, token.context); } } else if (num_contexts == 1) { - for (size_t j = 0; j < tokens[i].size(); ++j) { - const Token& token = tokens[i][j]; + for (const auto& token : stream) { total_tokens++; uint32_t tok, nbits, bits; uint_config.Encode(token.value, &tok, &nbits, &bits); builder.VisitSymbol(tok, /*token.context=*/0); } } else { - for (size_t j = 0; j < tokens[i].size(); ++j) { - const Token& token = tokens[i][j]; + for (const auto& token : stream) { total_tokens++; uint32_t tok, nbits, bits; uint_config.Encode(token.value, &tok, &nbits, &bits); @@ -1654,10 +1658,10 @@ size_t BuildAndEncodeHistograms(const HistogramParams& params, codes->encoded_histograms.emplace_back(); BitWriter* histo_writer = &codes->encoded_histograms.back(); BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24); - BuildAndStoreANSEncodingData(params.ans_histogram_strategy, counts.data(), - alphabet_size, log_alpha_size, - codes->use_prefix_code, - &codes->encoding_info.back()[0], histo_writer); + BuildAndStoreANSEncodingData( + params.ans_histogram_strategy, counts.data(), alphabet_size, + log_alpha_size, codes->use_prefix_code, + codes->encoding_info.back().data(), histo_writer); allotment.ReclaimAndCharge(histo_writer, 0, nullptr); } @@ -1680,9 +1684,8 @@ size_t WriteTokens(const std::vector<Token>& tokens, size_t context_offset, BitWriter* writer) { size_t num_extra_bits = 0; if (codes.use_prefix_code) { - for (size_t i = 0; i < tokens.size(); i++) { + for (const auto& token : tokens) { uint32_t tok, nbits, bits; - const Token& token = tokens[i]; size_t histo = context_map[context_offset + token.context]; (token.is_lz77_length ? codes.lz77.length_uint_config : codes.uint_config[histo]) @@ -1693,7 +1696,8 @@ size_t WriteTokens(const std::vector<Token>& tokens, // codes.encoding_info[histo][tok].bits); // writer->Write(nbits, bits); uint64_t data = codes.encoding_info[histo][tok].bits; - data |= bits << codes.encoding_info[histo][tok].depth; + data |= static_cast<uint64_t>(bits) + << codes.encoding_info[histo][tok].depth; writer->Write(codes.encoding_info[histo][tok].depth + nbits, data); num_extra_bits += nbits; } @@ -1765,7 +1769,8 @@ void WriteTokens(const std::vector<Token>& tokens, const EntropyEncodingData& codes, const std::vector<uint8_t>& context_map, size_t context_offset, BitWriter* writer, size_t layer, AuxOut* aux_out) { - BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4); + // Theoretically, we could have 15 prefix code bits + 31 extra bits. + BitWriter::Allotment allotment(writer, 46 * tokens.size() + 32 * 1024 * 4); size_t num_extra_bits = WriteTokens(tokens, codes, context_map, context_offset, writer); allotment.ReclaimAndCharge(writer, layer, aux_out); @@ -1779,4 +1784,51 @@ void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) { ans_fuzzer_friendly_ = ans_fuzzer_friendly; #endif } + +HistogramParams HistogramParams::ForModular( + const CompressParams& cparams, + const std::vector<uint8_t>& extra_dc_precision, bool streaming_mode) { + HistogramParams params; + params.streaming_mode = streaming_mode; + if (cparams.speed_tier > SpeedTier::kKitten) { + params.clustering = HistogramParams::ClusteringType::kFast; + params.ans_histogram_strategy = + cparams.speed_tier > SpeedTier::kThunder + ? HistogramParams::ANSHistogramStrategy::kFast + : HistogramParams::ANSHistogramStrategy::kApproximate; + params.lz77_method = + cparams.decoding_speed_tier >= 3 && cparams.modular_mode + ? (cparams.speed_tier >= SpeedTier::kFalcon + ? HistogramParams::LZ77Method::kRLE + : HistogramParams::LZ77Method::kLZ77) + : HistogramParams::LZ77Method::kNone; + // Near-lossless DC, as well as modular mode, require choosing hybrid uint + // more carefully. + if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) || + (cparams.modular_mode && cparams.speed_tier < SpeedTier::kCheetah)) { + params.uint_method = HistogramParams::HybridUintMethod::kFast; + } else { + params.uint_method = HistogramParams::HybridUintMethod::kNone; + } + } else if (cparams.speed_tier <= SpeedTier::kTortoise) { + params.lz77_method = HistogramParams::LZ77Method::kOptimal; + } else { + params.lz77_method = HistogramParams::LZ77Method::kLZ77; + } + if (cparams.decoding_speed_tier >= 1) { + params.max_histograms = 12; + } + if (cparams.decoding_speed_tier >= 1 && cparams.responsive) { + params.lz77_method = cparams.speed_tier >= SpeedTier::kCheetah + ? HistogramParams::LZ77Method::kRLE + : cparams.speed_tier >= SpeedTier::kKitten + ? HistogramParams::LZ77Method::kLZ77 + : HistogramParams::LZ77Method::kOptimal; + } + if (cparams.decoding_speed_tier >= 2 && cparams.responsive) { + params.uint_method = HistogramParams::HybridUintMethod::k000; + params.force_huffman = true; + } + return params; +} } // namespace jxl |