summaryrefslogtreecommitdiffstats
path: root/third_party/jpeg-xl/lib/jxl/enc_cluster.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/enc_cluster.cc')
-rw-r--r--third_party/jpeg-xl/lib/jxl/enc_cluster.cc352
1 files changed, 352 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/enc_cluster.cc b/third_party/jpeg-xl/lib/jxl/enc_cluster.cc
new file mode 100644
index 0000000000..df1b31ddf7
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_cluster.cc
@@ -0,0 +1,352 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/enc_cluster.h"
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <map>
+#include <memory>
+#include <numeric>
+#include <queue>
+#include <tuple>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_context.h"
+#include "lib/jxl/base/fast_math-inl.h"
+#include "lib/jxl/enc_ans.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Eq;
+using hwy::HWY_NAMESPACE::IfThenZeroElse;
+
+template <class V>
+V Entropy(V count, V inv_total, V total) {
+ const HWY_CAPPED(float, Histogram::kRounding) d;
+ const auto zero = Set(d, 0.0f);
+ // TODO(eustas): why (0 - x) instead of Neg(x)?
+ return IfThenZeroElse(
+ Eq(count, total),
+ Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count)))));
+}
+
+void HistogramEntropy(const Histogram& a) {
+ a.entropy_ = 0.0f;
+ if (a.total_count_ == 0) return;
+
+ const HWY_CAPPED(float, Histogram::kRounding) df;
+ const HWY_CAPPED(int32_t, Histogram::kRounding) di;
+
+ const auto inv_tot = Set(df, 1.0f / a.total_count_);
+ auto entropy_lanes = Zero(df);
+ auto total = Set(df, a.total_count_);
+
+ for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) {
+ const auto counts = LoadU(di, &a.data_[i]);
+ entropy_lanes =
+ Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total));
+ }
+ a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes));
+}
+
+float HistogramDistance(const Histogram& a, const Histogram& b) {
+ if (a.total_count_ == 0 || b.total_count_ == 0) return 0;
+
+ const HWY_CAPPED(float, Histogram::kRounding) df;
+ const HWY_CAPPED(int32_t, Histogram::kRounding) di;
+
+ const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_));
+ auto distance_lanes = Zero(df);
+ auto total = Set(df, a.total_count_ + b.total_count_);
+
+ for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size());
+ i += Lanes(di)) {
+ const auto a_counts =
+ a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di);
+ const auto b_counts =
+ b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di);
+ const auto counts = ConvertTo(df, Add(a_counts, b_counts));
+ distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total));
+ }
+ const float total_distance = GetLane(SumOfLanes(df, distance_lanes));
+ return total_distance - a.entropy_ - b.entropy_;
+}
+
+constexpr const float kInfinity = std::numeric_limits<float>::infinity();
+
+float HistogramKLDivergence(const Histogram& actual, const Histogram& coding) {
+ if (actual.total_count_ == 0) return 0;
+ if (coding.total_count_ == 0) return kInfinity;
+
+ const HWY_CAPPED(float, Histogram::kRounding) df;
+ const HWY_CAPPED(int32_t, Histogram::kRounding) di;
+
+ const auto coding_inv = Set(df, 1.0f / coding.total_count_);
+ auto cost_lanes = Zero(df);
+
+ for (size_t i = 0; i < actual.data_.size(); i += Lanes(di)) {
+ const auto counts = LoadU(di, &actual.data_[i]);
+ const auto coding_counts =
+ coding.data_.size() > i ? LoadU(di, &coding.data_[i]) : Zero(di);
+ const auto coding_probs = Mul(ConvertTo(df, coding_counts), coding_inv);
+ const auto neg_coding_cost = BitCast(
+ df,
+ IfThenZeroElse(Eq(counts, Zero(di)),
+ IfThenElse(Eq(coding_counts, Zero(di)),
+ BitCast(di, Set(df, -kInfinity)),
+ BitCast(di, FastLog2f(df, coding_probs)))));
+ cost_lanes = NegMulAdd(ConvertTo(df, counts), neg_coding_cost, cost_lanes);
+ }
+ const float total_cost = GetLane(SumOfLanes(df, cost_lanes));
+ return total_cost - actual.entropy_;
+}
+
+// First step of a k-means clustering with a fancy distance metric.
+void FastClusterHistograms(const std::vector<Histogram>& in,
+ size_t max_histograms, std::vector<Histogram>* out,
+ std::vector<uint32_t>* histogram_symbols) {
+ const size_t prev_histograms = out->size();
+ out->reserve(max_histograms);
+ histogram_symbols->clear();
+ histogram_symbols->resize(in.size(), max_histograms);
+
+ std::vector<float> dists(in.size(), std::numeric_limits<float>::max());
+ size_t largest_idx = 0;
+ for (size_t i = 0; i < in.size(); i++) {
+ if (in[i].total_count_ == 0) {
+ (*histogram_symbols)[i] = 0;
+ dists[i] = 0.0f;
+ continue;
+ }
+ HistogramEntropy(in[i]);
+ if (in[i].total_count_ > in[largest_idx].total_count_) {
+ largest_idx = i;
+ }
+ }
+
+ if (prev_histograms > 0) {
+ for (size_t j = 0; j < prev_histograms; ++j) {
+ HistogramEntropy((*out)[j]);
+ }
+ for (size_t i = 0; i < in.size(); i++) {
+ if (dists[i] == 0.0f) continue;
+ for (size_t j = 0; j < prev_histograms; ++j) {
+ dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]);
+ }
+ }
+ auto max_dist = std::max_element(dists.begin(), dists.end());
+ if (*max_dist > 0.0f) {
+ largest_idx = max_dist - dists.begin();
+ }
+ }
+
+ constexpr float kMinDistanceForDistinct = 48.0f;
+ while (out->size() < max_histograms) {
+ (*histogram_symbols)[largest_idx] = out->size();
+ out->push_back(in[largest_idx]);
+ dists[largest_idx] = 0.0f;
+ largest_idx = 0;
+ for (size_t i = 0; i < in.size(); i++) {
+ if (dists[i] == 0.0f) continue;
+ dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]);
+ if (dists[i] > dists[largest_idx]) largest_idx = i;
+ }
+ if (dists[largest_idx] < kMinDistanceForDistinct) break;
+ }
+
+ for (size_t i = 0; i < in.size(); i++) {
+ if ((*histogram_symbols)[i] != max_histograms) continue;
+ size_t best = 0;
+ float best_dist = std::numeric_limits<float>::max();
+ for (size_t j = 0; j < out->size(); j++) {
+ float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j])
+ : HistogramDistance(in[i], (*out)[j]);
+ if (dist < best_dist) {
+ best = j;
+ best_dist = dist;
+ }
+ }
+ JXL_ASSERT(best_dist < std::numeric_limits<float>::max());
+ if (best >= prev_histograms) {
+ (*out)[best].AddHistogram(in[i]);
+ HistogramEntropy((*out)[best]);
+ }
+ (*histogram_symbols)[i] = best;
+ }
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+HWY_EXPORT(FastClusterHistograms); // Local function
+HWY_EXPORT(HistogramEntropy); // Local function
+
+float Histogram::PopulationCost() const {
+ return ANSPopulationCost(data_.data(), data_.size());
+}
+
+float Histogram::ShannonEntropy() const {
+ HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this);
+ return entropy_;
+}
+
+namespace {
+// -----------------------------------------------------------------------------
+// Histogram refinement
+
+// Reorder histograms in *out so that the new symbols in *symbols come in
+// increasing order.
+void HistogramReindex(std::vector<Histogram>* out, size_t prev_histograms,
+ std::vector<uint32_t>* symbols) {
+ std::vector<Histogram> tmp(*out);
+ std::map<int, int> new_index;
+ for (size_t i = 0; i < prev_histograms; ++i) {
+ new_index[i] = i;
+ }
+ int next_index = prev_histograms;
+ for (uint32_t symbol : *symbols) {
+ if (new_index.find(symbol) == new_index.end()) {
+ new_index[symbol] = next_index;
+ (*out)[next_index] = tmp[symbol];
+ ++next_index;
+ }
+ }
+ out->resize(next_index);
+ for (uint32_t& symbol : *symbols) {
+ symbol = new_index[symbol];
+ }
+}
+
+} // namespace
+
+// Clusters similar histograms in 'in' together, the selected histograms are
+// placed in 'out', and for each index in 'in', *histogram_symbols will
+// indicate which of the 'out' histograms is the best approximation.
+void ClusterHistograms(const HistogramParams params,
+ const std::vector<Histogram>& in, size_t max_histograms,
+ std::vector<Histogram>* out,
+ std::vector<uint32_t>* histogram_symbols) {
+ size_t prev_histograms = out->size();
+ max_histograms = std::min(max_histograms, params.max_histograms);
+ max_histograms = std::min(max_histograms, in.size());
+ if (params.clustering == HistogramParams::ClusteringType::kFastest) {
+ max_histograms = std::min(max_histograms, static_cast<size_t>(4));
+ }
+
+ HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
+ (in, prev_histograms + max_histograms, out, histogram_symbols);
+
+ if (prev_histograms == 0 &&
+ params.clustering == HistogramParams::ClusteringType::kBest) {
+ for (size_t i = 0; i < out->size(); i++) {
+ (*out)[i].entropy_ =
+ ANSPopulationCost((*out)[i].data_.data(), (*out)[i].data_.size());
+ }
+ uint32_t next_version = 2;
+ std::vector<uint32_t> version(out->size(), 1);
+ std::vector<uint32_t> renumbering(out->size());
+ std::iota(renumbering.begin(), renumbering.end(), 0);
+
+ // Try to pair up clusters if doing so reduces the total cost.
+
+ struct HistogramPair {
+ // validity of a pair: p.version == max(version[i], version[j])
+ float cost;
+ uint32_t first;
+ uint32_t second;
+ uint32_t version;
+ // We use > because priority queues sort in *decreasing* order, but we
+ // want lower cost elements to appear first.
+ bool operator<(const HistogramPair& other) const {
+ return std::make_tuple(cost, first, second, version) >
+ std::make_tuple(other.cost, other.first, other.second,
+ other.version);
+ }
+ };
+
+ // Create list of all pairs by increasing merging cost.
+ std::priority_queue<HistogramPair> pairs_to_merge;
+ for (uint32_t i = 0; i < out->size(); i++) {
+ for (uint32_t j = i + 1; j < out->size(); j++) {
+ Histogram histo;
+ histo.AddHistogram((*out)[i]);
+ histo.AddHistogram((*out)[j]);
+ float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
+ (*out)[i].entropy_ - (*out)[j].entropy_;
+ // Avoid enqueueing pairs that are not advantageous to merge.
+ if (cost >= 0) continue;
+ pairs_to_merge.push(
+ HistogramPair{cost, i, j, std::max(version[i], version[j])});
+ }
+ }
+
+ // Merge the best pair to merge, add new pairs that get formed as a
+ // consequence.
+ while (!pairs_to_merge.empty()) {
+ uint32_t first = pairs_to_merge.top().first;
+ uint32_t second = pairs_to_merge.top().second;
+ uint32_t ver = pairs_to_merge.top().version;
+ pairs_to_merge.pop();
+ if (ver != std::max(version[first], version[second]) ||
+ version[first] == 0 || version[second] == 0) {
+ continue;
+ }
+ (*out)[first].AddHistogram((*out)[second]);
+ (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(),
+ (*out)[first].data_.size());
+ for (size_t i = 0; i < renumbering.size(); i++) {
+ if (renumbering[i] == second) {
+ renumbering[i] = first;
+ }
+ }
+ version[second] = 0;
+ version[first] = next_version++;
+ for (uint32_t j = 0; j < out->size(); j++) {
+ if (j == first) continue;
+ if (version[j] == 0) continue;
+ Histogram histo;
+ histo.AddHistogram((*out)[first]);
+ histo.AddHistogram((*out)[j]);
+ float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
+ (*out)[first].entropy_ - (*out)[j].entropy_;
+ // Avoid enqueueing pairs that are not advantageous to merge.
+ if (cost >= 0) continue;
+ pairs_to_merge.push(
+ HistogramPair{cost, std::min(first, j), std::max(first, j),
+ std::max(version[first], version[j])});
+ }
+ }
+ std::vector<uint32_t> reverse_renumbering(out->size(), -1);
+ size_t num_alive = 0;
+ for (size_t i = 0; i < out->size(); i++) {
+ if (version[i] == 0) continue;
+ (*out)[num_alive++] = (*out)[i];
+ reverse_renumbering[i] = num_alive - 1;
+ }
+ out->resize(num_alive);
+ for (size_t i = 0; i < histogram_symbols->size(); i++) {
+ (*histogram_symbols)[i] =
+ reverse_renumbering[renumbering[(*histogram_symbols)[i]]];
+ }
+ }
+
+ // Convert the context map to a canonical form.
+ HistogramReindex(out, prev_histograms, histogram_symbols);
+}
+
+} // namespace jxl
+#endif // HWY_ONCE