summaryrefslogtreecommitdiffstats
path: root/third_party/jpeg-xl/lib/jxl/enc_ans.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/enc_ans.cc')
-rw-r--r--third_party/jpeg-xl/lib/jxl/enc_ans.cc172
1 files changed, 112 insertions, 60 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/enc_ans.cc b/third_party/jpeg-xl/lib/jxl/enc_ans.cc
index 3efa62d8e1..5e59790b1e 100644
--- a/third_party/jpeg-xl/lib/jxl/enc_ans.cc
+++ b/third_party/jpeg-xl/lib/jxl/enc_ans.cc
@@ -5,11 +5,13 @@
#include "lib/jxl/enc_ans.h"
+#include <jxl/types.h>
#include <stdint.h>
#include <algorithm>
#include <array>
#include <cmath>
+#include <cstdint>
#include <limits>
#include <numeric>
#include <type_traits>
@@ -20,12 +22,15 @@
#include "lib/jxl/ans_common.h"
#include "lib/jxl/base/bits.h"
#include "lib/jxl/base/fast_math-inl.h"
+#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/enc_ans_params.h"
#include "lib/jxl/enc_aux_out.h"
#include "lib/jxl/enc_cluster.h"
#include "lib/jxl/enc_context_map.h"
#include "lib/jxl/enc_fields.h"
#include "lib/jxl/enc_huffman.h"
+#include "lib/jxl/enc_params.h"
#include "lib/jxl/fields.h"
namespace jxl {
@@ -37,7 +42,7 @@ constexpr
#endif
bool ans_fuzzer_friendly_ = false;
-static const int kMaxNumSymbolsForSmallCode = 4;
+const int kMaxNumSymbolsForSmallCode = 4;
void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
size_t alphabet_size, size_t log_alpha_size,
@@ -99,16 +104,16 @@ float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
// Static Huffman code for encoding logcounts. The last symbol is used as RLE
// sequence.
-static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
+const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
};
-static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
+const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
};
// Returns the difference between largest count that can be represented and is
// smaller than "count" and smallest representable count larger than "count".
-static int SmallestIncrement(uint32_t count, uint32_t shift) {
+int SmallestIncrement(uint32_t count, uint32_t shift) {
int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
return drop_bits < 0 ? 1 : (1 << drop_bits);
@@ -148,10 +153,11 @@ bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
int inc = SmallestIncrement(counts[n], shift);
counts[n] -= counts[n] & (inc - 1);
// TODO(robryk): Should we rescale targets[n]?
- const float target =
- minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
+ const int target = minimize_error_of_sum
+ ? (static_cast<int>(sum_nonrounded) - sum)
+ : static_cast<int>(targets[n]);
if (counts[n] == 0 ||
- (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
+ (target >= counts[n] + inc / 2 && counts[n] + inc < table_size)) {
counts[n] += inc;
}
sum += counts[n];
@@ -203,11 +209,11 @@ Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
for (size_t n = 0; n < targets.size(); ++n) {
targets[n] = norm * counts[n];
}
- if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
+ if (!RebalanceHistogram<false>(targets.data(), max_symbol, table_size, shift,
omit_pos, counts)) {
// Use an alternative rebalancing mechanism if the one above failed
// to create a histogram that is positive wherever the original one was.
- if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
+ if (!RebalanceHistogram<true>(targets.data(), max_symbol, table_size, shift,
omit_pos, counts)) {
return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
}
@@ -482,8 +488,8 @@ size_t BuildAndStoreANSEncodingData(
std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
if (!counts.empty()) {
size_t sum = 0;
- for (size_t i = 0; i < counts.size(); i++) {
- sum += counts[i];
+ for (int count : counts) {
+ sum += count;
}
if (sum == 0) {
counts[0] = ANS_TAB_SIZE;
@@ -538,8 +544,8 @@ template <typename Writer>
void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
Writer* writer, size_t log_alpha_size) {
// TODO(veluca): RLE?
- for (size_t i = 0; i < uint_config.size(); i++) {
- EncodeUintConfig(uint_config[i], writer, log_alpha_size);
+ for (const auto& cfg : uint_config) {
+ EncodeUintConfig(cfg, writer, log_alpha_size);
}
}
template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
@@ -553,8 +559,7 @@ void ChooseUintConfigs(const HistogramParams& params,
std::vector<Histogram>* clustered_histograms,
EntropyEncodingData* codes, size_t* log_alpha_size) {
codes->uint_config.resize(clustered_histograms->size());
- if (params.streaming_mode ||
- params.uint_method == HistogramParams::HybridUintMethod::kNone) {
+ if (params.uint_method == HistogramParams::HybridUintMethod::kNone) {
return;
}
if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
@@ -570,6 +575,12 @@ void ChooseUintConfigs(const HistogramParams& params,
return;
}
+ // If the uint config is adaptive, just stick with the default in streaming
+ // mode.
+ if (params.streaming_mode) {
+ return;
+ }
+
// Brute-force method that tries a few options.
std::vector<HybridUintConfig> configs;
if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
@@ -619,12 +630,11 @@ void ChooseUintConfigs(const HistogramParams& params,
std::fill(is_valid.begin(), is_valid.end(), true);
std::fill(extra_bits.begin(), extra_bits.end(), 0);
- for (size_t i = 0; i < clustered_histograms->size(); i++) {
- (*clustered_histograms)[i].Clear();
+ for (auto& histo : *clustered_histograms) {
+ histo.Clear();
}
- for (size_t i = 0; i < tokens.size(); ++i) {
- for (size_t j = 0; j < tokens[i].size(); ++j) {
- const Token token = tokens[i][j];
+ for (const auto& stream : tokens) {
+ for (const auto& token : stream) {
// TODO(veluca): do not ignore lz77 commands.
if (token.is_lz77_length) continue;
size_t histo = context_map[token.context];
@@ -632,7 +642,7 @@ void ChooseUintConfigs(const HistogramParams& params,
cfg.Encode(token.value, &tok, &nbits, &bits);
if (tok >= max_alpha ||
(codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
- is_valid[histo] = false;
+ is_valid[histo] = JXL_FALSE;
continue;
}
extra_bits[histo] += nbits;
@@ -654,13 +664,12 @@ void ChooseUintConfigs(const HistogramParams& params,
}
// Rebuild histograms.
- for (size_t i = 0; i < clustered_histograms->size(); i++) {
- (*clustered_histograms)[i].Clear();
+ for (auto& histo : *clustered_histograms) {
+ histo.Clear();
}
*log_alpha_size = 4;
- for (size_t i = 0; i < tokens.size(); ++i) {
- for (size_t j = 0; j < tokens[i].size(); ++j) {
- const Token token = tokens[i][j];
+ for (const auto& stream : tokens) {
+ for (const auto& token : stream) {
uint32_t tok, nbits, bits;
size_t histo = context_map[token.context];
(token.is_lz77_length ? codes->lz77.length_uint_config
@@ -771,7 +780,7 @@ class HistogramBuilder {
}
SizeWriter size_writer; // Used if writer == nullptr to estimate costs.
cost += 1;
- if (writer) writer->Write(1, codes->use_prefix_code);
+ if (writer) writer->Write(1, TO_JXL_BOOL(codes->use_prefix_code));
if (codes->use_prefix_code) {
log_alpha_size = PREFIX_MAX_BITS;
@@ -785,8 +794,8 @@ class HistogramBuilder {
EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
}
if (codes->use_prefix_code) {
- for (size_t c = 0; c < clustered_histograms.size(); ++c) {
- size_t alphabet_size = clustered_histograms[c].alphabet_size();
+ for (const auto& histo : clustered_histograms) {
+ size_t alphabet_size = histo.alphabet_size();
if (writer) {
StoreVarLenUint16(alphabet_size - 1, writer);
} else {
@@ -832,9 +841,8 @@ class SymbolCostEstimator {
HistogramBuilder builder(num_contexts);
// Build histograms for estimating lz77 savings.
HybridUintConfig uint_config;
- for (size_t i = 0; i < tokens.size(); ++i) {
- for (size_t j = 0; j < tokens[i].size(); ++j) {
- const Token token = tokens[i][j];
+ for (const auto& stream : tokens) {
+ for (const auto& token : stream) {
uint32_t tok, nbits, bits;
(token.is_lz77_length ? lz77.length_uint_config : uint_config)
.Encode(token.value, &tok, &nbits, &bits);
@@ -1025,12 +1033,7 @@ struct HashChain {
// Count down, so if due to small distance multiplier multiple distances
// map to the same code, the smallest code will be used in the end.
for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
- int xi = kSpecialDistances[i][0];
- int yi = kSpecialDistances[i][1];
- int distance = yi * distance_multiplier + xi;
- // Ensure that we map distance 1 to the lowest symbols.
- if (distance < 1) distance = 1;
- special_dist_table_[distance] = i;
+ special_dist_table_[SpecialDistance(i, distance_multiplier)] = i;
}
num_special_distances_ = kNumSpecialDistances;
}
@@ -1041,9 +1044,9 @@ struct HashChain {
if (pos + 2 < size_) {
// TODO(lode): take the MSB's of the uint32_t values into account as well,
// given that the hash code itself is less than 32 bits.
- result ^= (uint32_t)(data_[pos + 0] << 0u);
- result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
- result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
+ result ^= static_cast<uint32_t>(data_[pos + 0] << 0u);
+ result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_);
+ result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2));
} else {
// No need to compute hash of last 2 bytes, the length 2 is too short.
return 0;
@@ -1071,7 +1074,7 @@ struct HashChain {
uint32_t hashval = GetHash(pos);
uint32_t wpos = pos & window_mask_;
- val[wpos] = (int)hashval;
+ val[wpos] = static_cast<int>(hashval);
if (head[hashval] != -1) chain[wpos] = head[hashval];
head[hashval] = wpos;
@@ -1142,7 +1145,10 @@ struct HashChain {
} else {
if (hashpos == chain[hashpos]) break;
hashpos = chain[hashpos];
- if (val[hashpos] != (int)hashval) break; // outdated hash value
+ if (val[hashpos] != static_cast<int>(hashval)) {
+ // outdated hash value
+ break;
+ }
}
}
}
@@ -1274,7 +1280,8 @@ void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
distance_multiplier);
- size_t len, dist_symbol;
+ size_t len;
+ size_t dist_symbol;
const size_t max_lazy_match_len = 256; // 0 to disable lazy matching
@@ -1507,7 +1514,7 @@ void EncodeHistograms(const std::vector<uint8_t>& context_map,
}
EncodeContextMap(context_map, codes.encoding_info.size(), writer, layer,
aux_out);
- writer->Write(1, codes.use_prefix_code);
+ writer->Write(1, TO_JXL_BOOL(codes.use_prefix_code));
size_t log_alpha_size = 8;
if (codes.use_prefix_code) {
log_alpha_size = PREFIX_MAX_BITS;
@@ -1583,10 +1590,9 @@ size_t BuildAndEncodeHistograms(const HistogramParams& params,
if (ans_fuzzer_friendly_) {
uint_config = HybridUintConfig(10, 0, 0);
}
- for (size_t i = 0; i < tokens.size(); ++i) {
+ for (const auto& stream : tokens) {
if (codes->lz77.enabled) {
- for (size_t j = 0; j < tokens[i].size(); ++j) {
- const Token& token = tokens[i][j];
+ for (const auto& token : stream) {
total_tokens++;
uint32_t tok, nbits, bits;
(token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
@@ -1595,16 +1601,14 @@ size_t BuildAndEncodeHistograms(const HistogramParams& params,
builder.VisitSymbol(tok, token.context);
}
} else if (num_contexts == 1) {
- for (size_t j = 0; j < tokens[i].size(); ++j) {
- const Token& token = tokens[i][j];
+ for (const auto& token : stream) {
total_tokens++;
uint32_t tok, nbits, bits;
uint_config.Encode(token.value, &tok, &nbits, &bits);
builder.VisitSymbol(tok, /*token.context=*/0);
}
} else {
- for (size_t j = 0; j < tokens[i].size(); ++j) {
- const Token& token = tokens[i][j];
+ for (const auto& token : stream) {
total_tokens++;
uint32_t tok, nbits, bits;
uint_config.Encode(token.value, &tok, &nbits, &bits);
@@ -1654,10 +1658,10 @@ size_t BuildAndEncodeHistograms(const HistogramParams& params,
codes->encoded_histograms.emplace_back();
BitWriter* histo_writer = &codes->encoded_histograms.back();
BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24);
- BuildAndStoreANSEncodingData(params.ans_histogram_strategy, counts.data(),
- alphabet_size, log_alpha_size,
- codes->use_prefix_code,
- &codes->encoding_info.back()[0], histo_writer);
+ BuildAndStoreANSEncodingData(
+ params.ans_histogram_strategy, counts.data(), alphabet_size,
+ log_alpha_size, codes->use_prefix_code,
+ codes->encoding_info.back().data(), histo_writer);
allotment.ReclaimAndCharge(histo_writer, 0, nullptr);
}
@@ -1680,9 +1684,8 @@ size_t WriteTokens(const std::vector<Token>& tokens,
size_t context_offset, BitWriter* writer) {
size_t num_extra_bits = 0;
if (codes.use_prefix_code) {
- for (size_t i = 0; i < tokens.size(); i++) {
+ for (const auto& token : tokens) {
uint32_t tok, nbits, bits;
- const Token& token = tokens[i];
size_t histo = context_map[context_offset + token.context];
(token.is_lz77_length ? codes.lz77.length_uint_config
: codes.uint_config[histo])
@@ -1693,7 +1696,8 @@ size_t WriteTokens(const std::vector<Token>& tokens,
// codes.encoding_info[histo][tok].bits);
// writer->Write(nbits, bits);
uint64_t data = codes.encoding_info[histo][tok].bits;
- data |= bits << codes.encoding_info[histo][tok].depth;
+ data |= static_cast<uint64_t>(bits)
+ << codes.encoding_info[histo][tok].depth;
writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
num_extra_bits += nbits;
}
@@ -1765,7 +1769,8 @@ void WriteTokens(const std::vector<Token>& tokens,
const EntropyEncodingData& codes,
const std::vector<uint8_t>& context_map, size_t context_offset,
BitWriter* writer, size_t layer, AuxOut* aux_out) {
- BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
+ // Theoretically, we could have 15 prefix code bits + 31 extra bits.
+ BitWriter::Allotment allotment(writer, 46 * tokens.size() + 32 * 1024 * 4);
size_t num_extra_bits =
WriteTokens(tokens, codes, context_map, context_offset, writer);
allotment.ReclaimAndCharge(writer, layer, aux_out);
@@ -1779,4 +1784,51 @@ void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#endif
}
+
+HistogramParams HistogramParams::ForModular(
+ const CompressParams& cparams,
+ const std::vector<uint8_t>& extra_dc_precision, bool streaming_mode) {
+ HistogramParams params;
+ params.streaming_mode = streaming_mode;
+ if (cparams.speed_tier > SpeedTier::kKitten) {
+ params.clustering = HistogramParams::ClusteringType::kFast;
+ params.ans_histogram_strategy =
+ cparams.speed_tier > SpeedTier::kThunder
+ ? HistogramParams::ANSHistogramStrategy::kFast
+ : HistogramParams::ANSHistogramStrategy::kApproximate;
+ params.lz77_method =
+ cparams.decoding_speed_tier >= 3 && cparams.modular_mode
+ ? (cparams.speed_tier >= SpeedTier::kFalcon
+ ? HistogramParams::LZ77Method::kRLE
+ : HistogramParams::LZ77Method::kLZ77)
+ : HistogramParams::LZ77Method::kNone;
+ // Near-lossless DC, as well as modular mode, require choosing hybrid uint
+ // more carefully.
+ if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) ||
+ (cparams.modular_mode && cparams.speed_tier < SpeedTier::kCheetah)) {
+ params.uint_method = HistogramParams::HybridUintMethod::kFast;
+ } else {
+ params.uint_method = HistogramParams::HybridUintMethod::kNone;
+ }
+ } else if (cparams.speed_tier <= SpeedTier::kTortoise) {
+ params.lz77_method = HistogramParams::LZ77Method::kOptimal;
+ } else {
+ params.lz77_method = HistogramParams::LZ77Method::kLZ77;
+ }
+ if (cparams.decoding_speed_tier >= 1) {
+ params.max_histograms = 12;
+ }
+ if (cparams.decoding_speed_tier >= 1 && cparams.responsive) {
+ params.lz77_method = cparams.speed_tier >= SpeedTier::kCheetah
+ ? HistogramParams::LZ77Method::kRLE
+ : cparams.speed_tier >= SpeedTier::kKitten
+ ? HistogramParams::LZ77Method::kLZ77
+ : HistogramParams::LZ77Method::kOptimal;
+ }
+ if (cparams.decoding_speed_tier >= 2 && cparams.responsive) {
+ params.uint_method = HistogramParams::HybridUintMethod::k000;
+ params.force_huffman = true;
+ }
+ return params;
+}
} // namespace jxl