summary refs log tree commit diff stats
path: root/third_party/jpeg-xl/lib/jxl/modular
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/modular')
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h 626
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc 107
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h 66
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc 124
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h 27
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc 562
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h 47
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc 1023
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h 157
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc 622
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h 135
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h 28
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/modular_image.cc 77
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/modular_image.h 118
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/options.h 117
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc 606
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h 22
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc 73
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h 17
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc 141
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h 20
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc 46
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h 22
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc 176
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/palette.h 129
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc 153
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/rct.h 20
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc 478
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h 90
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc 98
-rw-r--r-- third_party/jpeg-xl/lib/jxl/modular/transform/transform.h 148
31 files changed, 6075 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h
new file mode 100644
index 0000000000..914cd6a4e4
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h
@@ -0,0 +1,626 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
+#define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
+
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/fields.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/options.h"
+
+namespace jxl {
+
+namespace weighted {
+constexpr static size_t kNumPredictors = 4; // sub-predictors blended by State.
+constexpr static int64_t kPredExtraBits = 3; // left shift applied by State::AddBits.
+constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1; // added before >> kPredExtraBits.
+constexpr static size_t kNumProperties = 1; // properties written by State::Predict.
+
+struct Header : public Fields { // serialized parameters read by weighted::State.
+ JXL_FIELDS_NAME(WeightedPredictorHeader)
+ // TODO(janwas): move to cc file, avoid including fields.h.
+ Header() { Bundle::Init(this); }
+
+ Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
+ if (visitor->AllDefault(*this, &all_default)) {
+ // Overwrite all serialized fields, but not any nonserialized_*.
+ visitor->SetDefault(this);
+ return true;
+ }
+ auto visit_p = [visitor](pixel_type val, pixel_type *p) { // 5-bit field; val is presumably the default - confirm in fields.h.
+ uint32_t up = *p;
+ JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
+ *p = up;
+ return Status(true);
+ };
+ JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
+ JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
+ JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
+ JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
+ JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
+ JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
+ JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
+ JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0])); // the weights are 4-bit fields.
+ JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
+ JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
+ JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
+ return true;
+ }
+
+ bool all_default; // written by visitor->AllDefault() in VisitFields.
+ pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0; // coefficients used by State::Predict.
+ uint32_t w[kNumPredictors] = {}; // passed as maxweight to State::ErrorWeight.
+};
+
+struct State { // per-channel working state of the weighted predictor.
+ pixel_type_w prediction[kNumPredictors] = {}; // sub-predictions, with extra bits.
+ pixel_type_w pred = 0; // *before* removing the added bits.
+ std::vector<uint32_t> pred_errors[kNumPredictors]; // two rows per sub-predictor.
+ std::vector<int32_t> error; // signed error of `pred`, two rows.
+ const Header header;
+
+ // Allows to approximate division by a number from 1 to 64.
+ uint32_t divlookup[64];
+
+ constexpr static pixel_type_w AddBits(pixel_type_w x) {
+ return uint64_t(x) << kPredExtraBits;
+ }
+
+ State(Header header, size_t xsize, size_t ysize) : header(header) {
+ // Extra margin to avoid out-of-bounds writes.
+ // All have space for two rows of data.
+ for (size_t i = 0; i < kNumPredictors; i++) { // one buffer per sub-predictor.
+ pred_errors[i].resize((xsize + 2) * 2);
+ }
+ error.resize((xsize + 2) * 2);
+ // Initialize division lookup table.
+ for (int i = 0; i < 64; i++) {
+ divlookup[i] = (1 << 24) / (i + 1);
+ }
+ }
+
+ // Approximates 4+(maxweight<<24)/(x+1), avoiding division
+ JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
+ int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5;
+ if (shift < 0) shift = 0;
+ return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
+ }
+
+ // Approximates the weighted average of the input values with the given
+ // weights, avoiding division. Weights must sum to at least 16.
+ JXL_INLINE pixel_type_w
+ WeightedAverage(const pixel_type_w *JXL_RESTRICT p,
+ std::array<uint32_t, kNumPredictors> w) const {
+ uint32_t weight_sum = 0;
+ for (size_t i = 0; i < kNumPredictors; i++) {
+ weight_sum += w[i];
+ }
+ JXL_DASSERT(weight_sum > 15);
+ uint32_t log_weight = FloorLog2Nonzero(weight_sum); // at least 4.
+ weight_sum = 0;
+ for (size_t i = 0; i < kNumPredictors; i++) {
+ w[i] >>= log_weight - 4; // rescale the weights before re-summing them.
+ weight_sum += w[i];
+ }
+ // for rounding.
+ pixel_type_w sum = (weight_sum >> 1) - 1;
+ for (size_t i = 0; i < kNumPredictors; i++) {
+ sum += p[i] * w[i];
+ }
+ return (sum * divlookup[weight_sum - 1]) >> 24;
+ }
+
+ template <bool compute_properties>
+ JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
+ pixel_type_w N, pixel_type_w W,
+ pixel_type_w NE, pixel_type_w NW,
+ pixel_type_w NN, Properties *properties,
+ size_t offset) {
+ size_t cur_row = y & 1 ? 0 : (xsize + 2); // the two buffer rows alternate with y.
+ size_t prev_row = y & 1 ? (xsize + 2) : 0;
+ size_t pos_N = prev_row + x;
+ size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N; // falls back to N at the row edges.
+ size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
+ std::array<uint32_t, kNumPredictors> weights;
+ for (size_t i = 0; i < kNumPredictors; i++) {
+ // pred_errors[pos_N] also contains the error of pixel W.
+ // pred_errors[pos_NW] also contains the error of pixel WW.
+ weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
+ pred_errors[i][pos_NW];
+ weights[i] = ErrorWeight(weights[i], header.w[i]);
+ }
+
+ N = AddBits(N);
+ W = AddBits(W);
+ NE = AddBits(NE);
+ NW = AddBits(NW);
+ NN = AddBits(NN);
+
+ pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
+ pixel_type_w teN = error[pos_N];
+ pixel_type_w teNW = error[pos_NW];
+ pixel_type_w sumWN = teN + teW;
+ pixel_type_w teNE = error[pos_NE];
+
+ if (compute_properties) {
+ pixel_type_w p = teW; // the property is the neighbour error of largest magnitude.
+ if (std::abs(teN) > std::abs(p)) p = teN;
+ if (std::abs(teNW) > std::abs(p)) p = teNW;
+ if (std::abs(teNE) > std::abs(p)) p = teNE;
+ (*properties)[offset++] = p;
+ }
+
+ prediction[0] = W + NE - N;
+ prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
+ prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
+ prediction[3] =
+ N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
+ (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
+ 5);
+
+ pred = WeightedAverage(prediction, weights);
+
+ // If all three have the same sign, skip clamping.
+ if (((teN ^ teW) | (teN ^ teNW)) > 0) {
+ return (pred + kPredictionRound) >> kPredExtraBits;
+ }
+
+ // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
+ pixel_type_w mx = std::max(W, std::max(NE, N));
+ pixel_type_w mn = std::min(W, std::min(NE, N));
+ pred = std::max(mn, std::min(mx, pred));
+ return (pred + kPredictionRound) >> kPredExtraBits;
+ }
+
+ JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
+ size_t xsize) {
+ size_t cur_row = y & 1 ? 0 : (xsize + 2); // same row layout as in Predict.
+ size_t prev_row = y & 1 ? (xsize + 2) : 0;
+ val = AddBits(val);
+ error[cur_row + x] = pred - val;
+ for (size_t i = 0; i < kNumPredictors; i++) {
+ pixel_type_w err =
+ (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
+ // For predicting in the next row.
+ pred_errors[i][cur_row + x] = err;
+ // Add the error on this pixel to the error on the NE pixel. This has the
+ // effect of adding the error on this pixel to the E and EE pixels.
+ pred_errors[i][prev_row + x + 1] += err;
+ }
+ }
+};
+
+// Encoder helper: sets the weights and p* coefficients of *header to preset `i`.
+inline void PredictorMode(int i, Header *header) {
+ switch (i) {
+ case 0: // same values Header::VisitFields passes to Bits().
+ // ~ lossless16 predictor
+ header->w[0] = 0xd;
+ header->w[1] = 0xc;
+ header->w[2] = 0xc;
+ header->w[3] = 0xc;
+ header->p1C = 16;
+ header->p2C = 10;
+ header->p3Ca = 7;
+ header->p3Cb = 7;
+ header->p3Cc = 7;
+ header->p3Cd = 0;
+ header->p3Ce = 0;
+ break;
+ case 1:
+ // ~ default lossless8 predictor
+ header->w[0] = 0xd;
+ header->w[1] = 0xc;
+ header->w[2] = 0xc;
+ header->w[3] = 0xb;
+ header->p1C = 8;
+ header->p2C = 8;
+ header->p3Ca = 4;
+ header->p3Cb = 0;
+ header->p3Cc = 3;
+ header->p3Cd = 23;
+ header->p3Ce = 2;
+ break;
+ case 2:
+ // ~ west lossless8 predictor
+ header->w[0] = 0xd;
+ header->w[1] = 0xc;
+ header->w[2] = 0xd;
+ header->w[3] = 0xc;
+ header->p1C = 10;
+ header->p2C = 9;
+ header->p3Ca = 7;
+ header->p3Cb = 0;
+ header->p3Cc = 0;
+ header->p3Cd = 16;
+ header->p3Ce = 9;
+ break;
+ case 3:
+ // ~ north lossless8 predictor
+ header->w[0] = 0xd;
+ header->w[1] = 0xd;
+ header->w[2] = 0xc;
+ header->w[3] = 0xc;
+ header->p1C = 16;
+ header->p2C = 8;
+ header->p3Ca = 0;
+ header->p3Cb = 16;
+ header->p3Cc = 0;
+ header->p3Cd = 23;
+ header->p3Ce = 0;
+ break;
+ case 4:
+ default: // any value outside 0..3 selects this preset.
+ // something else, because why not
+ header->w[0] = 0xd;
+ header->w[1] = 0xc;
+ header->w[2] = 0xc;
+ header->w[3] = 0xc;
+ header->p1C = 10;
+ header->p2C = 10;
+ header->p3Ca = 5;
+ header->p3Cb = 5;
+ header->p3Cc = 5;
+ header->p3Cd = 12;
+ header->p3Ce = 4;
+ break;
+ }
+}
+} // namespace weighted
+
+// Stores a node and its two children at the same time. This significantly
+// reduces the number of branches needed during decoding.
+struct FlatDecisionNode {
+ // Property + splitval of the top node.
+ int32_t property0; // -1 if leaf.
+ union {
+ PropertyVal splitval0; // read by MATreeLookup when this is a split.
+ Predictor predictor; // read by MATreeLookup when this is a leaf.
+ };
+ uint32_t childID; // childID is ctx id if leaf.
+ // Property+splitval of the two child nodes.
+ union {
+ PropertyVal splitvals[2];
+ int32_t multiplier; // read by MATreeLookup when this is a leaf.
+ };
+ union {
+ int32_t properties[2];
+ int64_t predictor_offset; // read by MATreeLookup when this is a leaf.
+ };
+};
+using FlatTree = std::vector<FlatDecisionNode>;
+
+class MATreeLookup { // walks a FlatTree to find the leaf matching a property vector.
+ public:
+ explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {} // stores a reference; `tree` must outlive this object.
+ struct LookupResult {
+ uint32_t context;
+ Predictor predictor;
+ int64_t offset;
+ int32_t multiplier;
+ };
+ JXL_INLINE LookupResult Lookup(const Properties &properties) const {
+ uint32_t pos = 0; // the walk starts at node 0.
+ while (true) {
+ const FlatDecisionNode &node = nodes_[pos];
+ if (node.property0 < 0) { // leaf: childID is returned as the context.
+ return {node.childID, node.predictor, node.predictor_offset,
+ node.multiplier};
+ }
+ bool p0 = properties[node.property0] <= node.splitval0;
+ uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0]; // 0 or 1.
+ uint32_t off1 =
+ 2 | (properties[node.properties[1]] <= node.splitvals[1] ? 1 : 0); // 2 or 3.
+ pos = node.childID + (p0 ? off1 : off0); // two tree levels are resolved per iteration.
+ }
+ }
+
+ private:
+ const FlatTree &nodes_;
+};
+
+static constexpr size_t kExtraPropsPerChannel = 4; // values PrecomputeReferences stores per reference channel.
+static constexpr size_t kNumNonrefProperties =
+ kNumStaticProperties + 13 + weighted::kNumProperties;
+
+constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties; // first weighted-predictor property index.
+constexpr size_t kGradientProp = 9; // index where detail::Predict stores left + top - topleft.
+
+// Returns the gradient n + w - l clamped to [min(n, w), max(n, w)].
+static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w,
+ const int32_t l) {
+ const int32_t m = std::min(n, w);
+ const int32_t M = std::max(n, w);
+ // The end result of this operation doesn't overflow or underflow if the
+ // result is between m and M, but the intermediate value may overflow, so we
+ // do the intermediate operations in uint32_t and check later if we had an
+ // overflow or underflow condition comparing m, M and l directly.
+ // grad = M + m - l = n + w - l
+ const int32_t grad =
+ static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) -
+ static_cast<uint32_t>(l));
+ // We use two sets of ternary operators to force the evaluation of them in
+ // any case, allowing the compiler to avoid branches and use cmovl/cmovg in
+ // x86.
+ const int32_t grad_clamp_M = (l < m) ? M : grad; // l < m means grad would exceed M.
+ return (l > M) ? m : grad_clamp_M; // l > M means grad would fall below m.
+}
+
+inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) { // returns whichever of a, b is closer to a + b - c.
+ pixel_type_w p = a + b - c;
+ pixel_type_w pa = std::abs(p - a); // equals |b - c|.
+ pixel_type_w pb = std::abs(p - b); // equals |a - c|.
+ return pa < pb ? a : b; // ties go to b.
+}
+
+inline void PrecomputeReferences(const Channel &ch, size_t y, // fills `references` for row y of channel i.
+ const Image &image, uint32_t i,
+ Channel *references) {
+ ZeroFillImage(&references->plane);
+ uint32_t offset = 0;
+ size_t num_extra_props = references->w; // width of `references` caps how many values are stored.
+ intptr_t onerow = references->plane.PixelsPerRow();
+ for (int32_t j = static_cast<int32_t>(i) - 1; // scan earlier channels, nearest first.
+ j >= 0 && offset < num_extra_props; j--) {
+ if (image.channel[j].w != image.channel[i].w ||
+ image.channel[j].h != image.channel[i].h) {
+ continue;
+ }
+ if (image.channel[j].hshift != image.channel[i].hshift) continue;
+ if (image.channel[j].vshift != image.channel[i].vshift) continue;
+ pixel_type *JXL_RESTRICT rp = references->Row(0) + offset;
+ const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y);
+ const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0);
+ for (size_t x = 0; x < ch.w; x++, rp += onerow) { // rp advances one row of `references` per x.
+ pixel_type_w v = rpp[x];
+ rp[0] = std::abs(v);
+ rp[1] = v;
+ pixel_type_w vleft = (x ? rpp[x - 1] : 0);
+ pixel_type_w vtop = (y ? rpprev[x] : vleft);
+ pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft);
+ pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft);
+ rp[2] = std::abs(v - vpredicted); // magnitude of the gradient-prediction residual.
+ rp[3] = v - vpredicted;
+ }
+
+ offset += kExtraPropsPerChannel;
+ }
+}
+
+struct PredictionResult { // return value of detail::Predict and its wrappers.
+ int context = 0; // set from the tree lookup when kUseTree is used.
+ pixel_type_w guess = 0; // predicted value; includes the leaf offset when kUseTree is used.
+ Predictor predictor; // always set by detail::Predict.
+ int32_t multiplier; // NOTE: only set by detail::Predict when kUseTree is used.
+};
+
+inline void InitPropsRow( // prepares *p for row y: static properties, y, and a zeroed gradient slot.
+ Properties *p,
+ const std::array<pixel_type, kNumStaticProperties> &static_props,
+ const int y) {
+ for (size_t i = 0; i < kNumStaticProperties; i++) {
+ (*p)[i] = static_props[i];
+ }
+ (*p)[2] = y;
+ (*p)[kGradientProp] = 0; // local gradient.
+}
+
+namespace detail {
+enum PredictorMode { // bit flags combined into the `mode` template argument of Predict.
+ kUseTree = 1, // look up context/predictor/offset/multiplier in the MA tree.
+ kUseWP = 2, // run the weighted predictor.
+ kForceComputeProperties = 4, // fill the property vector even without a tree.
+ kAllPredictions = 8, // write every predictor's output to `predictions`.
+ kNoEdgeCases = 16 // read all neighbours unconditionally; skip reference properties.
+};
+
+JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left, // evaluates predictor p on the given neighbours.
+ pixel_type_w top, pixel_type_w toptop,
+ pixel_type_w topleft, pixel_type_w topright,
+ pixel_type_w leftleft,
+ pixel_type_w toprightright,
+ pixel_type_w wp_pred) {
+ switch (p) {
+ case Predictor::Zero:
+ return pixel_type_w{0};
+ case Predictor::Left:
+ return left;
+ case Predictor::Top:
+ return top;
+ case Predictor::Select:
+ return Select(left, top, topleft);
+ case Predictor::Weighted:
+ return wp_pred; // computed by the caller.
+ case Predictor::Gradient:
+ return pixel_type_w{ClampedGradient(left, top, topleft)};
+ case Predictor::TopLeft:
+ return topleft;
+ case Predictor::TopRight:
+ return topright;
+ case Predictor::LeftLeft:
+ return leftleft;
+ case Predictor::Average0:
+ return (left + top) / 2;
+ case Predictor::Average1:
+ return (left + topleft) / 2;
+ case Predictor::Average2:
+ return (topleft + top) / 2;
+ case Predictor::Average3:
+ return (top + topright) / 2;
+ case Predictor::Average4: // the coefficients sum to 16; the +8 rounds the division.
+ return (6 * top - 2 * toptop + 7 * left + 1 * leftleft +
+ 1 * toprightright + 3 * topright + 8) /
+ 16;
+ default: // any other predictor value predicts 0.
+ return pixel_type_w{0};
+ }
+}
+
+template <int mode> // bitwise OR of detail::PredictorMode flags.
+JXL_INLINE PredictionResult Predict(
+ Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
+ const intptr_t onerow, const size_t x, const size_t y, Predictor predictor,
+ const MATreeLookup *lookup, const Channel *references,
+ weighted::State *wp_state, pixel_type_w *predictions) {
+ // We start in position 3 because of 2 static properties + y.
+ size_t offset = 3;
+ constexpr bool compute_properties =
+ mode & kUseTree || mode & kForceComputeProperties;
+ constexpr bool nec = mode & kNoEdgeCases;
+ pixel_type_w left = (nec || x ? pp[-1] : (y ? pp[-onerow] : 0)); // missing neighbours use fallbacks.
+ pixel_type_w top = (nec || y ? pp[-onerow] : left);
+ pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left);
+ pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top);
+ pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left);
+ pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top);
+ pixel_type_w toprightright =
+ (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright);
+
+ if (compute_properties) {
+ // location
+ (*p)[offset++] = x;
+ // neighbors
+ (*p)[offset++] = std::abs(top);
+ (*p)[offset++] = std::abs(left);
+ (*p)[offset++] = top;
+ (*p)[offset++] = left;
+
+ // left minus the value still in the next slot (reset to 0 by InitPropsRow)
+ (*p)[offset] = left - (*p)[offset + 1];
+ offset++;
+ // local gradient
+ (*p)[offset++] = left + top - topleft;
+
+ // FFV1 context properties
+ (*p)[offset++] = left - topleft;
+ (*p)[offset++] = topleft - top;
+ (*p)[offset++] = top - topright;
+ (*p)[offset++] = top - toptop;
+ (*p)[offset++] = left - leftleft;
+ }
+
+ pixel_type_w wp_pred = 0;
+ if (mode & kUseWP) {
+ wp_pred = wp_state->Predict<compute_properties>(
+ x, y, w, top, left, topright, topleft, toptop, p, offset);
+ }
+ if (!nec && compute_properties) {
+ offset += weighted::kNumProperties;
+ // Extra properties.
+ const pixel_type *JXL_RESTRICT rp = references->Row(x);
+ for (size_t i = 0; i < references->w; i++) {
+ (*p)[offset++] = rp[i];
+ }
+ }
+ PredictionResult result;
+ if (mode & kUseTree) {
+ MATreeLookup::LookupResult lr = lookup->Lookup(*p);
+ result.context = lr.context;
+ result.guess = lr.offset;
+ result.multiplier = lr.multiplier;
+ predictor = lr.predictor; // the tree leaf overrides the caller's predictor.
+ }
+ if (mode & kAllPredictions) {
+ for (size_t i = 0; i < kNumModularPredictors; i++) {
+ predictions[i] = PredictOne(static_cast<Predictor>(i), left, top, toptop,
+ topleft, topright, leftleft, toprightright, wp_pred);
+ }
+ }
+ result.guess += PredictOne(predictor, left, top, toptop, topleft, topright,
+ leftleft, toprightright, wp_pred);
+ result.predictor = predictor;
+
+ return result;
+}
+} // namespace detail
+
+inline PredictionResult PredictNoTreeNoWP(size_t w,
+ const pixel_type *JXL_RESTRICT pp,
+ const intptr_t onerow, const int x,
+ const int y, Predictor predictor) {
+ return detail::Predict</*mode=*/0>( // given predictor only: no tree, no WP, no properties.
+ /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
+ /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr);
+}
+
+inline PredictionResult PredictNoTreeWP(size_t w,
+ const pixel_type *JXL_RESTRICT pp,
+ const intptr_t onerow, const int x,
+ const int y, Predictor predictor,
+ weighted::State *wp_state) {
+ return detail::Predict<detail::kUseWP>( // given predictor, with the weighted predictor run on wp_state.
+ /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
+ /*references=*/nullptr, wp_state, /*predictions=*/nullptr);
+}
+
+inline PredictionResult PredictTreeNoWP(Properties *p, size_t w,
+ const pixel_type *JXL_RESTRICT pp,
+ const intptr_t onerow, const int x,
+ const int y,
+ const MATreeLookup &tree_lookup,
+ const Channel &references) {
+ return detail::Predict<detail::kUseTree>( // fills *p, then takes the predictor from the tree.
+ p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
+ /*wp_state=*/nullptr, /*predictions=*/nullptr);
+}
+// No-edge-case variant. Only use for y > 1, x > 1, x < w-2, and empty references
+JXL_INLINE PredictionResult
+PredictTreeNoWPNEC(Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
+ const intptr_t onerow, const int x, const int y,
+ const MATreeLookup &tree_lookup, const Channel &references) {
+ return detail::Predict<detail::kUseTree | detail::kNoEdgeCases>( // neighbours are read without bounds checks.
+ p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
+ /*wp_state=*/nullptr, /*predictions=*/nullptr);
+}
+
+inline PredictionResult PredictTreeWP(Properties *p, size_t w,
+ const pixel_type *JXL_RESTRICT pp,
+ const intptr_t onerow, const int x,
+ const int y,
+ const MATreeLookup &tree_lookup,
+ const Channel &references,
+ weighted::State *wp_state) {
+ return detail::Predict<detail::kUseTree | detail::kUseWP>( // tree lookup plus the weighted predictor.
+ p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
+ wp_state, /*predictions=*/nullptr);
+}
+
+inline PredictionResult PredictLearn(Properties *p, size_t w,
+ const pixel_type *JXL_RESTRICT pp,
+ const intptr_t onerow, const int x,
+ const int y, Predictor predictor,
+ const Channel &references,
+ weighted::State *wp_state) {
+ return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>( // fills *p without a tree; uses the given predictor.
+ p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
+ wp_state, /*predictions=*/nullptr);
+}
+
+inline void PredictLearnAll(Properties *p, size_t w,
+ const pixel_type *JXL_RESTRICT pp,
+ const intptr_t onerow, const int x, const int y,
+ const Channel &references,
+ weighted::State *wp_state,
+ pixel_type_w *predictions) { // predictions receives kNumModularPredictors values.
+ detail::Predict<detail::kForceComputeProperties | detail::kUseWP | // the PredictionResult is discarded.
+ detail::kAllPredictions>(
+ p, w, pp, onerow, x, y, Predictor::Zero,
+ /*lookup=*/nullptr, &references, wp_state, predictions);
+}
+
+inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp,
+ const intptr_t onerow, const int x, const int y,
+ pixel_type_w *predictions) { // predictions receives kNumModularPredictors values.
+ detail::Predict<detail::kAllPredictions>( // wp_pred stays 0 because kUseWP is not set.
+ /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero,
+ /*lookup=*/nullptr,
+ /*references=*/nullptr, /*wp_state=*/nullptr, predictions);
+}
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc
new file mode 100644
index 0000000000..66562f7dfd
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc
@@ -0,0 +1,107 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/encoding/dec_ma.h"
+
+#include "lib/jxl/base/printf_macros.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/modular/encoding/ma_common.h"
+#include "lib/jxl/modular/modular_image.h"
+
+namespace jxl {
+
+namespace {
+
+Status ValidateTree( // recursively checks each split value against the value range still reachable.
+ const Tree &tree,
+ const std::vector<std::pair<pixel_type, pixel_type>> &prop_bounds,
+ size_t root) {
+ if (tree[root].property == -1) return true; // leaves are always valid.
+ size_t p = tree[root].property;
+ int val = tree[root].splitval;
+ if (prop_bounds[p].first > val) return JXL_FAILURE("Invalid tree");
+ // Splitting at max value makes no sense: left range will be exactly same
+ // as parent, right range will be invalid (min > max).
+ if (prop_bounds[p].second <= val) return JXL_FAILURE("Invalid tree");
+ auto new_bounds = prop_bounds; // copied so sibling subtrees see the unmodified bounds.
+ new_bounds[p].first = val + 1; // lchild covers property > splitval.
+ JXL_RETURN_IF_ERROR(ValidateTree(tree, new_bounds, tree[root].lchild));
+ new_bounds[p] = prop_bounds[p];
+ new_bounds[p].second = val; // rchild covers property <= splitval.
+ return ValidateTree(tree, new_bounds, tree[root].rchild);
+}
+
+Status DecodeTree(BitReader *br, ANSSymbolReader *reader, // reads nodes into *tree, then validates it.
+ const std::vector<uint8_t> &context_map, Tree *tree,
+ size_t tree_size_limit) {
+ size_t leaf_id = 0;
+ size_t to_decode = 1; // nodes still to be read; starts with the root.
+ tree->clear();
+ while (to_decode > 0) {
+ JXL_RETURN_IF_ERROR(br->AllReadsWithinBounds());
+ if (tree->size() > tree_size_limit) {
+ return JXL_FAILURE("Tree is too large: %" PRIuS " nodes vs %" PRIuS
+ " max nodes",
+ tree->size(), tree_size_limit);
+ }
+ to_decode--;
+ uint32_t prop1 = reader->ReadHybridUint(kPropertyContext, br, context_map);
+ if (prop1 > 256) return JXL_FAILURE("Invalid tree property value");
+ int property = prop1 - 1; // a decoded 0 means leaf (-1).
+ if (property == -1) {
+ size_t predictor =
+ reader->ReadHybridUint(kPredictorContext, br, context_map);
+ if (predictor >= kNumModularPredictors) {
+ return JXL_FAILURE("Invalid predictor");
+ }
+ int64_t predictor_offset =
+ UnpackSigned(reader->ReadHybridUint(kOffsetContext, br, context_map));
+ uint32_t mul_log =
+ reader->ReadHybridUint(kMultiplierLogContext, br, context_map);
+ if (mul_log >= 31) {
+ return JXL_FAILURE("Invalid multiplier logarithm");
+ }
+ uint32_t mul_bits =
+ reader->ReadHybridUint(kMultiplierBitsContext, br, context_map);
+ if (mul_bits + 1 >= 1u << (31u - mul_log)) { // keeps the multiplier below 2^31.
+ return JXL_FAILURE("Invalid multiplier");
+ }
+ uint32_t multiplier = (mul_bits + 1U) << mul_log;
+ tree->emplace_back(-1, 0, leaf_id++, 0, static_cast<Predictor>(predictor), // lchild stores the leaf id.
+ predictor_offset, multiplier);
+ continue;
+ }
+ int splitval =
+ UnpackSigned(reader->ReadHybridUint(kSplitValContext, br, context_map));
+ tree->emplace_back(property, splitval, tree->size() + to_decode + 1, // children follow all pending nodes.
+ tree->size() + to_decode + 2, Predictor::Zero, 0, 1);
+ to_decode += 2;
+ }
+ std::vector<std::pair<pixel_type, pixel_type>> prop_bounds;
+ prop_bounds.resize(256, {std::numeric_limits<pixel_type>::min(), // 256 entries cover property indices 0..255.
+ std::numeric_limits<pixel_type>::max()});
+ return ValidateTree(*tree, prop_bounds, 0);
+}
+} // namespace
+
+Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit) { // public entry point: histograms, nodes, ANS final-state check.
+ std::vector<uint8_t> tree_context_map;
+ ANSCode tree_code;
+ JXL_RETURN_IF_ERROR(
+ DecodeHistograms(br, kNumTreeContexts, &tree_code, &tree_context_map));
+ // TODO(eustas): investigate more infinite tree cases.
+ if (tree_code.degenerate_symbols[tree_context_map[kPropertyContext]] > 0) {
+ return JXL_FAILURE("Infinite tree");
+ }
+ ANSSymbolReader reader(&tree_code, br);
+ JXL_RETURN_IF_ERROR(DecodeTree(br, &reader, tree_context_map, tree,
+ std::min(tree_size_limit, kMaxTreeSize))); // the limit is capped at kMaxTreeSize.
+ if (!reader.CheckANSFinalState()) {
+ return JXL_FAILURE("ANS decode final state failed");
+ }
+ return true;
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h
new file mode 100644
index 0000000000..a910c4deb1
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h
@@ -0,0 +1,66 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_ENCODING_DEC_MA_H_
+#define LIB_JXL_MODULAR_ENCODING_DEC_MA_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/modular/options.h"
+
+namespace jxl {
+
+// Tree node: either a split on a property or a leaf (property == -1).
+struct PropertyDecisionNode {
+ PropertyVal splitval;
+ int16_t property; // -1: leaf node; DecodeTree then stores the leaf id in lchild
+ uint32_t lchild; // index of the child for property > splitval (see ValidateTree).
+ uint32_t rchild; // index of the child for property <= splitval.
+ Predictor predictor;
+ int64_t predictor_offset;
+ uint32_t multiplier;
+
+ PropertyDecisionNode(int p, int split_val, int lchild, int rchild,
+ Predictor predictor, int64_t predictor_offset,
+ uint32_t multiplier)
+ : splitval(split_val),
+ property(p),
+ lchild(lchild),
+ rchild(rchild),
+ predictor(predictor),
+ predictor_offset(predictor_offset),
+ multiplier(multiplier) {}
+ PropertyDecisionNode() // default: a leaf with predictor Zero, offset 0, multiplier 1.
+ : splitval(0),
+ property(-1),
+ lchild(0),
+ rchild(0),
+ predictor(Predictor::Zero),
+ predictor_offset(0),
+ multiplier(1) {}
+ static PropertyDecisionNode Leaf(Predictor predictor, int64_t offset = 0,
+ uint32_t multiplier = 1) {
+ return PropertyDecisionNode(-1, 0, 0, 0, predictor, offset, multiplier);
+ }
+ static PropertyDecisionNode Split(int p, int split_val, int lchild,
+ int rchild = -1) {
+ if (rchild == -1) rchild = lchild + 1; // by default rchild is the node right after lchild.
+ return PropertyDecisionNode(p, split_val, lchild, rchild, Predictor::Zero,
+ 0, 1);
+ }
+};
+
+using Tree = std::vector<PropertyDecisionNode>;
+
+Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit);
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_ENCODING_DEC_MA_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc
new file mode 100644
index 0000000000..f2a1705e4b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc
@@ -0,0 +1,124 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/encoding/enc_debug_tree.h"
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#include "lib/jxl/base/os_macros.h"
+#include "lib/jxl/base/printf_macros.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/encoding/dec_ma.h"
+#include "lib/jxl/modular/options.h"
+
+#if JXL_OS_IOS
+#define JXL_ENABLE_DOT 0 // iOS lacks C89 system()
+#else
+#define JXL_ENABLE_DOT 1
+#endif
+
+namespace jxl {
+
+const char *PredictorName(Predictor p) { // short label used by PrintTree for leaf nodes.
+ switch (p) {
+ case Predictor::Zero:
+ return "Zero";
+ case Predictor::Left:
+ return "Left";
+ case Predictor::Top:
+ return "Top";
+ case Predictor::Average0:
+ return "Avg0";
+ case Predictor::Average1:
+ return "Avg1";
+ case Predictor::Average2:
+ return "Avg2";
+ case Predictor::Average3:
+ return "Avg3";
+ case Predictor::Average4:
+ return "Avg4";
+ case Predictor::Select:
+ return "Sel";
+ case Predictor::Gradient:
+ return "Grd";
+ case Predictor::Weighted:
+ return "Wgh";
+ case Predictor::TopLeft:
+ return "TopL";
+ case Predictor::TopRight:
+ return "TopR";
+ case Predictor::LeftLeft:
+ return "LL";
+ default: // any value not listed above.
+ return "INVALID";
+ };
+}
+
+std::string PropertyName(size_t i) { // label used by PrintTree for split nodes.
+ static_assert(kNumNonrefProperties == 16, "Update this function");
+ switch (i) {
+ case 0:
+ return "c";
+ case 1:
+ return "g";
+ case 2:
+ return "y";
+ case 3:
+ return "x";
+ case 4:
+ return "|N|";
+ case 5:
+ return "|W|";
+ case 6:
+ return "N";
+ case 7:
+ return "W";
+ case 8:
+ return "W-WW-NW+NWW";
+ case 9:
+ return "W+N-NW";
+ case 10:
+ return "W-NW";
+ case 11:
+ return "NW-N";
+ case 12:
+ return "N-NE";
+ case 13:
+ return "N-NN";
+ case 14:
+ return "W-WW";
+ case 15:
+ return "WGH";
+ default: // i >= 16: the label shows 15 - i, which is negative.
+ return "ch[" + ToString(15 - (int)i) + "]";
+ }
+}
+
+void PrintTree(const Tree &tree, const std::string &path) { // writes <path>.dot and renders <path>.svg.
+ FILE *f = fopen((path + ".dot").c_str(), "w");
+ if (f == nullptr) return; // cannot create the .dot file; nothing to render.
+ fprintf(f, "graph{\n");
+ for (size_t cur = 0; cur < tree.size(); cur++) {
+ if (tree[cur].property < 0) { // leaf: predictor, offset and multiplier.
+ fprintf(f, "n%05" PRIuS " [label=\"%s%+" PRId64 " (x%u)\"];\n", cur,
+ PredictorName(tree[cur].predictor), tree[cur].predictor_offset,
+ tree[cur].multiplier);
+ } else {
+ fprintf(f, "n%05" PRIuS " [label=\"%s>%d\"];\n", cur,
+ PropertyName(tree[cur].property).c_str(), tree[cur].splitval);
+ fprintf(f, "n%05" PRIuS " -- n%05u;\n", cur, tree[cur].lchild); // child indices are uint32_t.
+ fprintf(f, "n%05" PRIuS " -- n%05u;\n", cur, tree[cur].rchild);
+ }
+ }
+ fprintf(f, "}\n");
+ fclose(f);
+#if JXL_ENABLE_DOT
+ JXL_ASSERT(system(("dot " + path + ".dot -T svg -o " + path + ".svg").c_str()) == 0);
+#endif
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h
new file mode 100644
index 0000000000..78deaab1b8
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h
@@ -0,0 +1,27 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_
+#define LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+#include <vector>
+
+#include "lib/jxl/modular/encoding/dec_ma.h"
+#include "lib/jxl/modular/options.h"
+
+namespace jxl {
+
+const char *PredictorName(Predictor p);  // Short label for `p` (e.g. "Grd" for Gradient); "INVALID" for values without a label.
+std::string PropertyName(size_t i);  // Label of property index `i` (e.g. "W-NW"); indices 16 and above give "ch[<15 - i>]".
+
+void PrintTree(const Tree &tree, const std::string &path);  // Writes `tree` in Graphviz format to `path` + ".dot" (and renders an .svg when JXL_ENABLE_DOT is set).
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc
new file mode 100644
index 0000000000..c8c183335e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc
@@ -0,0 +1,562 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <cinttypes>
+#include <limits>
+#include <numeric>
+#include <queue>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "lib/jxl/base/printf_macros.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/enc_aux_out.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/enc_fields.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/encoding/enc_debug_tree.h"
+#include "lib/jxl/modular/encoding/enc_ma.h"
+#include "lib/jxl/modular/encoding/encoding.h"
+#include "lib/jxl/modular/encoding/ma_common.h"
+#include "lib/jxl/modular/options.h"
+#include "lib/jxl/modular/transform/transform.h"
+#include "lib/jxl/toc.h"
+
+namespace jxl {
+
+namespace {
+// Compile-time debug switches: kWantDebug enables the predictor usage map,
+// kPrintTree (together with kWantDebug) enables dumping the learned tree.
+constexpr bool kWantDebug = false;
+constexpr bool kPrintTree = false;
+
+// RGB color that stands for predictor `p` in the debug predictor map.
+// Predictors without a dedicated color are shown in white.
+inline std::array<uint8_t, 3> PredictorColor(Predictor p) {
+  using Rgb = std::array<uint8_t, 3>;
+  Rgb color = Rgb{{255, 255, 255}};
+  switch (p) {
+    case Predictor::Zero:
+      color = Rgb{{0, 0, 0}};
+      break;
+    case Predictor::Left:
+      color = Rgb{{255, 0, 0}};
+      break;
+    case Predictor::Top:
+      color = Rgb{{0, 255, 0}};
+      break;
+    case Predictor::Average0:
+      color = Rgb{{0, 0, 255}};
+      break;
+    case Predictor::Average4:
+      color = Rgb{{192, 128, 128}};
+      break;
+    case Predictor::Select:
+      color = Rgb{{255, 255, 0}};
+      break;
+    case Predictor::Gradient:
+      color = Rgb{{255, 0, 255}};
+      break;
+    case Predictor::Weighted:
+      color = Rgb{{0, 255, 255}};
+      break;
+    // TODO
+    default:
+      break;
+  }
+  return color;
+}
+
+}  // namespace
+
+void GatherTreeData(const Image &image, pixel_type chan, size_t group_id,  // Runs the predictors over channel `chan` and adds a pseudo-random subset of its pixels to `tree_samples`.
+ const weighted::Header &wp_header,
+ const ModularOptions &options, TreeSamples &tree_samples,
+ size_t *total_pixels) {  // *total_pixels is incremented once per visited pixel, whether sampled or not.
+ const Channel &channel = image.channel[chan];
+
+ JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w,
+ channel.h, chan);
+
+ std::array<pixel_type, kNumStaticProperties> static_props = {  // Static properties: channel index and group id.
+ {chan, (int)group_id}};
+ Properties properties(kNumNonrefProperties +
+ kExtraPropsPerChannel * options.max_properties);
+ double pixel_fraction = std::min(1.0f, options.nb_repeats);  // Fraction of pixels to sample, capped at 1.
+ // a fraction of 0 is used to disable learning entirely.
+ if (pixel_fraction > 0) {
+ pixel_fraction = std::max(pixel_fraction,  // Raise the fraction so that small channels still yield about 1024 samples.
+ std::min(1.0, 1024.0 / (channel.w * channel.h)));
+ }
+ uint64_t threshold =  // Sampling threshold, compared against the top 32 bits of each random draw.
+ (std::numeric_limits<uint64_t>::max() >> 32) * pixel_fraction;
+ uint64_t s[2] = {static_cast<uint64_t>(0x94D049BB133111EBull),  // Fixed seed, so the sample selection is reproducible.
+ static_cast<uint64_t>(0xBF58476D1CE4E5B9ull)};
+ // Xorshift128+ adapted from xorshift128+-inl.h
+ auto use_sample = [&]() {
+ auto s1 = s[0];
+ const auto s0 = s[1];
+ const auto bits = s1 + s0; // b, c
+ s[0] = s0;
+ s1 ^= s1 << 23;
+ s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
+ s[1] = s1;
+ return (bits >> 32) <= threshold;
+ };
+
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ Channel references(properties.size() - kNumNonrefProperties, channel.w);
+ weighted::State wp_state(wp_header, channel.w, channel.h);
+ tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64);  // Expected number of samples plus some slack.
+ for (size_t y = 0; y < channel.h; y++) {
+ const pixel_type *JXL_RESTRICT p = channel.Row(y);
+ PrecomputeReferences(channel, y, image, chan, &references);
+ InitPropsRow(&properties, static_props, y);
+ // TODO(veluca): avoid computing WP if we don't use its property or
+ // predictions.
+ for (size_t x = 0; x < channel.w; x++) {
+ pixel_type_w pred[kNumModularPredictors];
+ if (tree_samples.NumPredictors() != 1) {  // Several candidate predictors: evaluate them all.
+ PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references,
+ &wp_state, pred);
+ } else {  // Single candidate predictor: only its slot of pred[] is filled in.
+ pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] =
+ PredictLearn(&properties, channel.w, p + x, onerow, x, y,
+ tree_samples.PredictorFromIndex(0), references,
+ &wp_state)
+ .guess;
+ }
+ (*total_pixels)++;
+ if (use_sample()) {
+ tree_samples.AddSample(p[x], properties, pred);
+ }
+ wp_state.UpdateErrors(p[x], x, y, channel.w);  // The weighted-predictor state is updated for every pixel, not only sampled ones.
+ }
+ }
+}
+
+// Builds an MA tree from the samples collected in `tree_samples`.
+// `total_pixels` is the number of pixels the samples were drawn from; the
+// ratio of samples to pixels scales the splitting threshold handed to
+// ComputeBestTree. Without any samples the result is a single leaf using the
+// first predictor of `tree_samples`.
+Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels,
+               const ModularOptions &options,
+               const std::vector<ModularMultiplierInfo> &multiplier_info = {},
+               StaticPropRange static_prop_range = {}) {
+  // A zero upper bound is replaced by the largest uint32_t value.
+  for (size_t prop = 0; prop < kNumStaticProperties; ++prop) {
+    if (static_prop_range[prop][1] != 0) continue;
+    static_prop_range[prop][1] = std::numeric_limits<uint32_t>::max();
+  }
+
+  Tree result;
+  if (tree_samples.HasSamples()) {
+    // The sample count has to be read before AllSamplesDone() is called.
+    const float sampled_fraction =
+        tree_samples.NumSamples() * 1.0f / total_pixels;
+    const float cost_scale = sampled_fraction * 0.9 + 0.1;
+    tree_samples.AllSamplesDone();
+    ComputeBestTree(tree_samples,
+                    options.splitting_heuristics_node_threshold * cost_scale,
+                    multiplier_info, static_prop_range,
+                    options.fast_decode_multiplier, &result);
+    return result;
+  }
+
+  // Nothing to learn from: emit one leaf with offset 0 and multiplier 1.
+  result.emplace_back();
+  auto &leaf = result.back();
+  leaf.predictor = tree_samples.PredictorFromIndex(0);
+  leaf.property = -1;
+  leaf.predictor_offset = 0;
+  leaf.multiplier = 1;
+  return result;
+}
+
+Status EncodeModularChannelMAANS(const Image &image, pixel_type chan,  // Turns every pixel of channel `chan` into one Token (context, packed residual) written through *tokenpp.
+ const weighted::Header &wp_header,
+ const Tree &global_tree, Token **tokenpp,  // *tokenpp is advanced by one token per pixel; the caller must have reserved the space.
+ AuxOut *aux_out, size_t group_id,
+ bool skip_encoder_fast_path) {  // When true, none of the specialised fast paths below is taken.
+ const Channel &channel = image.channel[chan];
+ Token *tokenp = *tokenpp;
+ JXL_ASSERT(channel.w != 0 && channel.h != 0);  // Callers are expected to skip empty channels.
+
+ Image3F predictor_img;
+ if (kWantDebug) predictor_img = Image3F(channel.w, channel.h);  // The debug map is only allocated when kWantDebug is set.
+
+ JXL_DEBUG_V(6,
+ "Encoding %" PRIuS "x%" PRIuS
+ " channel %d, "
+ "(shift=%i,%i)",
+ channel.w, channel.h, chan, channel.hshift, channel.vshift);
+
+ std::array<pixel_type, kNumStaticProperties> static_props = {
+ {chan, (int)group_id}};
+ bool use_wp, is_wp_only;
+ bool is_gradient_only;
+ size_t num_props;
+ FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp,  // Presumably specialises the global tree for this channel/group; also reports which fast paths may apply.
+ &is_wp_only, &is_gradient_only);
+ Properties properties(num_props);
+ MATreeLookup tree_lookup(tree);
+ JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size());
+
+ // Check if this tree is a WP-only tree with a small enough property value
+ // range.
+ // Initialized to avoid clang-tidy complaining.
+ uint16_t context_lookup[2 * kPropRangeFast] = {};
+ int8_t offsets[2 * kPropRangeFast] = {};
+ if (is_wp_only) {
+ is_wp_only = TreeToLookupTable(tree, context_lookup, offsets);
+ }
+ if (is_gradient_only) {
+ is_gradient_only = TreeToLookupTable(tree, context_lookup, offsets);
+ }
+
+ if (is_wp_only && !skip_encoder_fast_path) {  // Fast path 1: the tree was reduced to a lookup table on the weighted-predictor property.
+ for (size_t c = 0; c < 3; c++) {
+ FillImage(static_cast<float>(PredictorColor(Predictor::Weighted)[c]),
+ &predictor_img.Plane(c));
+ }
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ weighted::State wp_state(wp_header, channel.w, channel.h);
+ Properties properties(1);  // Shadows the outer `properties`; only the single WP property is needed here.
+ for (size_t y = 0; y < channel.h; y++) {
+ const pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ size_t offset = 0;
+ pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);  // Neighbours, with fallbacks at the image borders.
+ pixel_type_w top = (y ? *(r + x - onerow) : left);
+ pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
+ pixel_type_w topright =
+ (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top);
+ pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
+ int32_t guess = wp_state.Predict</*compute_properties=*/true>(
+ x, y, channel.w, top, left, topright, topleft, toptop, &properties,
+ offset);
+ uint32_t pos =  // Property clamped to [-kPropRangeFast, kPropRangeFast - 1] and shifted to a table index.
+ kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
+ kPropRangeFast - 1);
+ uint32_t ctx_id = context_lookup[pos];
+ int32_t residual = r[x] - guess - offsets[pos];
+ *tokenp++ = Token(ctx_id, PackSigned(residual));
+ wp_state.UpdateErrors(r[x], x, y, channel.w);
+ }
+ }
+ } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient &&  // Fast path 2: a single Gradient leaf with multiplier 1 and no offset.
+ tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
+ !skip_encoder_fast_path) {
+ for (size_t c = 0; c < 3; c++) {
+ FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
+ &predictor_img.Plane(c));
+ }
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ for (size_t y = 0; y < channel.h; y++) {
+ const pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
+ pixel_type_w top = (y ? *(r + x - onerow) : left);
+ pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
+ int32_t guess = ClampedGradient(top, left, topleft);
+ int32_t residual = r[x] - guess;
+ *tokenp++ = Token(tree[0].childID, PackSigned(residual));
+ }
+ }
+ } else if (is_gradient_only && !skip_encoder_fast_path) {  // Fast path 3: Gradient prediction with context/offset looked up from top + left - topleft.
+ for (size_t c = 0; c < 3; c++) {
+ FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
+ &predictor_img.Plane(c));
+ }
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ for (size_t y = 0; y < channel.h; y++) {
+ const pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
+ pixel_type_w top = (y ? *(r + x - onerow) : left);
+ pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
+ int32_t guess = ClampedGradient(top, left, topleft);
+ uint32_t pos =
+ kPropRangeFast +
+ std::min<pixel_type_w>(
+ std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
+ kPropRangeFast - 1);
+ uint32_t ctx_id = context_lookup[pos];
+ int32_t residual = r[x] - guess - offsets[pos];
+ *tokenp++ = Token(ctx_id, PackSigned(residual));
+ }
+ }
+ } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero &&  // Fast path 4: a single Zero leaf, so the residual is the pixel value itself.
+ tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
+ !skip_encoder_fast_path) {
+ for (size_t c = 0; c < 3; c++) {
+ FillImage(static_cast<float>(PredictorColor(Predictor::Zero)[c]),
+ &predictor_img.Plane(c));
+ }
+ for (size_t y = 0; y < channel.h; y++) {
+ const pixel_type *JXL_RESTRICT p = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ *tokenp++ = Token(tree[0].childID, PackSigned(p[x]));
+ }
+ }
+ } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted &&  // Fast path 5: a single non-weighted leaf without offset.
+ (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 &&
+ tree[0].predictor_offset == 0 && !skip_encoder_fast_path) {
+ // multiplier is a power of 2.
+ for (size_t c = 0; c < 3; c++) {
+ FillImage(static_cast<float>(PredictorColor(tree[0].predictor)[c]),
+ &predictor_img.Plane(c));
+ }
+ uint32_t mul_shift = FloorLog2Nonzero((uint32_t)tree[0].multiplier);  // Division by the multiplier is done as a right shift.
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ for (size_t y = 0; y < channel.h; y++) {
+ const pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x,
+ y, tree[0].predictor);
+ pixel_type_w residual = r[x] - pred.guess;
+ JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual);
+ *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift));
+ }
+ }
+
+ } else if (!use_wp && !skip_encoder_fast_path) {  // General tree walk without weighted-predictor state.
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ Channel references(properties.size() - kNumNonrefProperties, channel.w);
+ for (size_t y = 0; y < channel.h; y++) {
+ const pixel_type *JXL_RESTRICT p = channel.Row(y);
+ PrecomputeReferences(channel, y, image, chan, &references);
+ float *pred_img_row[3];
+ if (kWantDebug) {
+ for (size_t c = 0; c < 3; c++) {
+ pred_img_row[c] = predictor_img.PlaneRow(c, y);
+ }
+ }
+ InitPropsRow(&properties, static_props, y);
+ for (size_t x = 0; x < channel.w; x++) {
+ PredictionResult res =
+ PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
+ tree_lookup, references);
+ if (kWantDebug) {
+ for (size_t i = 0; i < 3; i++) {
+ pred_img_row[i][x] = PredictorColor(res.predictor)[i];
+ }
+ }
+ pixel_type_w residual = p[x] - res.guess;
+ JXL_ASSERT(residual % res.multiplier == 0);  // The residual must be an exact multiple of the leaf's multiplier.
+ *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
+ }
+ }
+ } else {  // General tree walk with weighted-predictor state; also taken whenever skip_encoder_fast_path is set.
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ Channel references(properties.size() - kNumNonrefProperties, channel.w);
+ weighted::State wp_state(wp_header, channel.w, channel.h);
+ for (size_t y = 0; y < channel.h; y++) {
+ const pixel_type *JXL_RESTRICT p = channel.Row(y);
+ PrecomputeReferences(channel, y, image, chan, &references);
+ float *pred_img_row[3];
+ if (kWantDebug) {
+ for (size_t c = 0; c < 3; c++) {
+ pred_img_row[c] = predictor_img.PlaneRow(c, y);
+ }
+ }
+ InitPropsRow(&properties, static_props, y);
+ for (size_t x = 0; x < channel.w; x++) {
+ PredictionResult res =
+ PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
+ tree_lookup, references, &wp_state);
+ if (kWantDebug) {
+ for (size_t i = 0; i < 3; i++) {
+ pred_img_row[i][x] = PredictorColor(res.predictor)[i];
+ }
+ }
+ pixel_type_w residual = p[x] - res.guess;
+ JXL_ASSERT(residual % res.multiplier == 0);
+ *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
+ wp_state.UpdateErrors(p[x], x, y, channel.w);
+ }
+ }
+ }
+ if (kWantDebug && WantDebugOutput(aux_out)) {
+ aux_out->DumpImage(
+ ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(),
+ predictor_img);
+ }
+ *tokenpp = tokenp;  // Hand the advanced write position back to the caller.
+ return true;
+}
+
+Status ModularEncode(const Image &image, const ModularOptions &options,  // Encodes `image`, or only gathers tree-learning samples from it, depending on which optional pointers are set.
+ BitWriter *writer, AuxOut *aux_out, size_t layer,
+ size_t group_id, TreeSamples *tree_samples,  // Non-null tree_samples (with null tree): only collect samples and return.
+ size_t *total_pixels, const Tree *tree,  // Non-null tree: encode with that tree and append the tokens to *tokens.
+ GroupHeader *header, std::vector<Token> *tokens,
+ size_t *width) {  // Receives the widest encoded channel width when a caller-supplied tree is used.
+ if (image.error) return JXL_FAILURE("Invalid image");
+ size_t nb_channels = image.channel.size();
+ JXL_DEBUG_V(
+ 2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.",
+ nb_channels, image.bitdepth, image.w, image.h);
+
+ if (nb_channels < 1) {
+ return true; // is there any use for a zero-channel image?
+ }
+
+ // encode transforms
+ GroupHeader header_storage;
+ if (header == nullptr) header = &header_storage;  // Use a local header when the caller does not want it back.
+ Bundle::Init(header);
+ if (options.predictor == Predictor::Weighted) {
+ weighted::PredictorMode(options.wp_mode, &header->wp_header);
+ }
+ header->transforms = image.transform;
+ // This doesn't actually work
+ if (tree != nullptr) {
+ header->use_global_tree = true;
+ }
+ if (tree_samples == nullptr && tree == nullptr) {  // The header is written here only when neither gathering nor using a caller-supplied tree.
+ JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out));
+ }
+
+ TreeSamples tree_samples_storage;
+ size_t total_pixels_storage = 0;
+ if (!total_pixels) total_pixels = &total_pixels_storage;
+ // If there's no tree, compute one (or gather data to).
+ if (tree == nullptr) {
+ bool gather_data = tree_samples != nullptr;  // True when the caller only wants samples collected.
+ if (tree_samples == nullptr) {  // Local learning: configure the local sample store first.
+ JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor(
+ options.predictor, options.wp_tree_mode));
+ JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties(
+ options.splitting_heuristics_properties, options.wp_tree_mode));
+ std::vector<pixel_type> pixel_samples;
+ std::vector<pixel_type> diff_samples;
+ std::vector<uint32_t> group_pixel_count;
+ std::vector<uint32_t> channel_pixel_count;
+ CollectPixelSamples(image, options, 0, group_pixel_count,
+ channel_pixel_count, pixel_samples, diff_samples);
+ std::vector<ModularMultiplierInfo> dummy_multiplier_info;
+ StaticPropRange range;
+ tree_samples_storage.PreQuantizeProperties(
+ range, dummy_multiplier_info, group_pixel_count, channel_pixel_count,
+ pixel_samples, diff_samples, options.max_property_values);
+ }
+ for (size_t i = 0; i < nb_channels; i++) {
+ if (!image.channel[i].w || !image.channel[i].h) {
+ continue; // skip empty channels
+ }
+ if (i >= image.nb_meta_channels &&
+ (image.channel[i].w > options.max_chan_size ||
+ image.channel[i].h > options.max_chan_size)) {
+ break;  // Stops at the first oversized non-meta channel; later channels are not visited either.
+ }
+ GatherTreeData(image, i, group_id, header->wp_header, options,
+ gather_data ? *tree_samples : tree_samples_storage,
+ total_pixels);
+ }
+ if (gather_data) return true;  // Gather-only mode ends here; nothing is written.
+ }
+
+ JXL_ASSERT((tree == nullptr) == (tokens == nullptr));  // A caller-supplied tree and a caller-supplied token vector must come together.
+
+ Tree tree_storage;
+ std::vector<std::vector<Token>> tokens_storage(1);
+ // Compute tree.
+ if (tree == nullptr) {
+ EntropyEncodingData code;
+ std::vector<uint8_t> context_map;
+
+ std::vector<std::vector<Token>> tree_tokens(1);
+ tree_storage =
+ LearnTree(std::move(tree_samples_storage), *total_pixels, options);
+ tree = &tree_storage;
+ tokens = &tokens_storage[0];
+
+ Tree decoded_tree;
+ TokenizeTree(*tree, &tree_tokens[0], &decoded_tree);
+ JXL_ASSERT(tree->size() == decoded_tree.size());
+ tree_storage = std::move(decoded_tree);  // From here on `tree` refers to the tree produced by TokenizeTree.
+
+ if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) {
+ PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id));
+ }
+ // Write tree
+ BuildAndEncodeHistograms(HistogramParams(), kNumTreeContexts, tree_tokens,
+ &code, &context_map, writer, kLayerModularTree,
+ aux_out);
+ WriteTokens(tree_tokens[0], code, context_map, writer, kLayerModularTree,
+ aux_out);
+ }
+
+ size_t image_width = 0;
+ size_t total_tokens = 0;
+ for (size_t i = 0; i < nb_channels; i++) {  // One token per pixel of every channel that will be encoded.
+ if (i >= image.nb_meta_channels &&
+ (image.channel[i].w > options.max_chan_size ||
+ image.channel[i].h > options.max_chan_size)) {
+ break;
+ }
+ if (image.channel[i].w > image_width) image_width = image.channel[i].w;
+ total_tokens += image.channel[i].w * image.channel[i].h;
+ }
+ if (options.zero_tokens) {  // Append all-zero placeholder tokens instead of encoding the pixels.
+ tokens->resize(tokens->size() + total_tokens, {0, 0});
+ } else {
+ // Do one big allocation for all the tokens we'll need,
+ // to avoid reallocs that might require copying.
+ size_t pos = tokens->size();
+ tokens->resize(pos + total_tokens);
+ Token *tokenp = tokens->data() + pos;
+ for (size_t i = 0; i < nb_channels; i++) {
+ if (!image.channel[i].w || !image.channel[i].h) {
+ continue; // skip empty channels
+ }
+ if (i >= image.nb_meta_channels &&
+ (image.channel[i].w > options.max_chan_size ||
+ image.channel[i].h > options.max_chan_size)) {
+ break;
+ }
+ JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS(
+ image, i, header->wp_header, *tree, &tokenp, aux_out, group_id,
+ options.skip_encoder_fast_path));
+ }
+ // Make sure we actually wrote all tokens
+ JXL_CHECK(tokenp == tokens->data() + tokens->size());
+ }
+
+ // Write data if not using a global tree/ANS stream.
+ if (!header->use_global_tree) {
+ EntropyEncodingData code;
+ std::vector<uint8_t> context_map;
+ HistogramParams histo_params;
+ histo_params.image_widths.push_back(image_width);
+ BuildAndEncodeHistograms(histo_params, (tree->size() + 1) / 2,
+ tokens_storage, &code, &context_map, writer, layer,
+ aux_out);
+ WriteTokens(tokens_storage[0], code, context_map, writer, layer, aux_out);
+ } else {
+ *width = image_width;  // Dereferenced unconditionally: `width` must be non-null when a tree is supplied.
+ }
+ return true;
+}
+
+// Entry point for modular encoding of `image`. Images with zero width or
+// height are accepted and produce no output. An unset predictor (-1) in
+// `opts` is replaced by Gradient before the work is handed to ModularEncode,
+// whose errors are propagated. With a writer, the encoded size is logged.
+Status ModularGenericCompress(Image &image, const ModularOptions &opts,
+                              BitWriter *writer, AuxOut *aux_out, size_t layer,
+                              size_t group_id, TreeSamples *tree_samples,
+                              size_t *total_pixels, const Tree *tree,
+                              GroupHeader *header, std::vector<Token> *tokens,
+                              size_t *width) {
+  const bool empty_image = (image.w == 0) || (image.h == 0);
+  if (empty_image) return true;
+
+  // Work on a private copy so that the default predictor can be filled in.
+  ModularOptions options = opts;
+  const bool predictor_unset =
+      (options.predictor == static_cast<Predictor>(-1));
+  if (predictor_unset) options.predictor = Predictor::Gradient;
+
+  const size_t bits_before = (writer != nullptr) ? writer->BitsWritten() : 0;
+  JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer,
+                                    group_id, tree_samples, total_pixels, tree,
+                                    header, tokens, width));
+  if (writer != nullptr) {
+    const size_t bits = writer->BitsWritten() - bits_before;
+    JXL_DEBUG_V(4,
+                "Modular-encoded a %" PRIuS "x%" PRIuS
+                " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes",
+                image.w, image.h, image.bitdepth, image.channel.size(),
+                bits / 8);
+    (void)bits;  // Only read by the debug macro above.
+  }
+  return true;
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h
new file mode 100644
index 0000000000..04df504750
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h
@@ -0,0 +1,47 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_
+#define LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/encoding/enc_ma.h"
+#include "lib/jxl/modular/encoding/encoding.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/options.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+namespace jxl {
+
+Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels,  // Builds an MA tree from `tree_samples`; `total_pixels` is the number of pixels the samples were drawn from.
+ const ModularOptions &options,
+ const std::vector<ModularMultiplierInfo> &multiplier_info = {},
+ StaticPropRange static_prop_range = {});  // A zero upper bound in static_prop_range is replaced by the uint32_t maximum.
+
+// TODO(veluca): make cleaner interfaces.
+
+Status ModularGenericCompress(  // Modular-encodes `image`; returns true right away for images with zero width or height.
+ Image &image, const ModularOptions &opts, BitWriter *writer,
+ AuxOut *aux_out = nullptr, size_t layer = 0, size_t group_id = 0,
+ // For gathering data for producing a global tree.
+ TreeSamples *tree_samples = nullptr, size_t *total_pixels = nullptr,
+ // For encoding with global tree.
+ const Tree *tree = nullptr, GroupHeader *header = nullptr,
+ std::vector<Token> *tokens = nullptr, size_t *widths = nullptr);  // NOTE(review): the definition names this last parameter `width`; it receives a single value.
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc
new file mode 100644
index 0000000000..d0f6b47566
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc
@@ -0,0 +1,1023 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/encoding/enc_ma.h"
+
+#include <algorithm>
+#include <limits>
+#include <numeric>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "lib/jxl/modular/encoding/ma_common.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/random.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/fast_math-inl.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/options.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Eq;
+using hwy::HWY_NAMESPACE::IfThenElse;
+using hwy::HWY_NAMESPACE::Lt;
+
+const HWY_FULL(float) df;  // Highway tag for float vectors on the current target.
+const HWY_FULL(int32_t) di;  // Highway tag for int32_t vectors on the current target.
+size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }  // Pads x with RoundUpTo against the float lane count (presumably up to the next multiple).
+
+float EstimateBits(const int32_t *counts, int32_t *rounded_counts,  // Estimates the coded size in bits of histogram `counts`; `rounded_counts` is overwritten as scratch.
+ size_t num_symbols) {  // Both arrays are accessed in whole vectors of Lanes(df) elements, so callers pass a padded size.
+ // Try to approximate the effect of rounding up nonzero probabilities.
+ int32_t total = std::accumulate(counts, counts + num_symbols, 0);
+ const auto min = Set(di, (total + ANS_TAB_SIZE - 1) >> ANS_LOG_TAB_SIZE);  // Smallest count a nonzero symbol is allowed to keep.
+ const auto zero_i = Zero(di);
+ for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
+ auto counts_v = LoadU(di, &counts[i]);
+ counts_v = IfThenElse(Eq(counts_v, zero_i), zero_i,  // Zero stays zero; nonzero counts below `min` are raised to `min`.
+ IfThenElse(Lt(counts_v, min), min, counts_v));
+ StoreU(counts_v, di, &rounded_counts[i]);
+ }
+ // Compute entropy of the "rounded" probabilities.
+ const auto zero = Zero(df);
+ const size_t total_scalar =
+ std::accumulate(rounded_counts, rounded_counts + num_symbols, 0);
+ const auto inv_total = Set(df, 1.0f / total_scalar);
+ auto bits_lanes = Zero(df);
+ auto total_v = Set(di, total_scalar);
+ for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
+ const auto counts_v = ConvertTo(df, LoadU(di, &counts[i]));
+ const auto round_counts_v = LoadU(di, &rounded_counts[i]);
+ const auto probs = Mul(ConvertTo(df, round_counts_v), inv_total);
+ const auto nbps = IfThenElse(Eq(round_counts_v, total_v), BitCast(di, zero),  // A symbol holding the whole mass costs 0 bits; the float bits pass through int lanes to match the integer mask.
+ BitCast(di, FastLog2f(df, probs)));
+ bits_lanes = Sub(bits_lanes, IfThenElse(Eq(counts_v, zero), zero,  // Accumulates -count * log2(prob), skipping symbols with zero count.
+ Mul(counts_v, BitCast(df, nbps))));
+ }
+ return GetLane(SumOfLanes(df, bits_lanes));
+}
+
+// Turns node `pos` of `*tree` into a decision node testing `property` against
+// `splitval`, and appends its two children as leaves at the end of the tree.
+// The first appended leaf (index lchild) receives rpred/roff, the second one
+// (index rchild) receives lpred/loff; both get multiplier 1.
+void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
+                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
+  // Note that the tree splits on *strictly greater*.
+  {
+    auto &split = (*tree)[pos];
+    split.lchild = tree->size();
+    split.rchild = tree->size() + 1;
+    split.splitval = splitval;
+    split.property = property;
+  }  // `split` is not used past this point: appending may reallocate.
+
+  const auto append_leaf = [tree](Predictor predictor, int64_t offset) {
+    tree->emplace_back();
+    auto &leaf = tree->back();
+    leaf.property = -1;
+    leaf.predictor = predictor;
+    leaf.predictor_offset = offset;
+    leaf.multiplier = 1;
+  };
+  append_leaf(rpred, roff);
+  append_leaf(lpred, loff);
+}
+
+enum class IntersectionType { kNone, kPartial, kInside };  // Result of BoxIntersects: no overlap, overlap with part of the needle outside, or needle fully covered.
+// Compares the box `needle` with the box `haystack`, one static property
+// (axis) at a time. Returns kNone as soon as the ranges of one axis do not
+// overlap, kInside if `haystack` covers `needle` on every axis, and kPartial
+// otherwise. For kPartial, `partial_axis` / `partial_val` describe the last
+// axis on which the needle sticks out: the axis index and the haystack
+// boundary lying strictly inside the needle, minus one.
+IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
+                               uint32_t &partial_axis, uint32_t &partial_val) {
+  bool sticks_out = false;
+  for (size_t axis = 0; axis < kNumStaticProperties; axis++) {
+    const auto hay_lo = haystack[axis][0];
+    const auto hay_hi = haystack[axis][1];
+    const auto needle_lo = needle[axis][0];
+    const auto needle_hi = needle[axis][1];
+
+    // Disjoint on this axis means disjoint boxes.
+    if (hay_lo >= needle_hi || hay_hi <= needle_lo) {
+      return IntersectionType::kNone;
+    }
+    const bool covered = (hay_lo <= needle_lo) && (hay_hi >= needle_hi);
+    if (covered) continue;
+
+    sticks_out = true;
+    partial_axis = axis;
+    const bool lower_bound_inside = (hay_lo > needle_lo) && (hay_lo < needle_hi);
+    if (!lower_bound_inside) {
+      JXL_DASSERT(hay_hi > needle_lo && hay_hi < needle_hi);
+    }
+    partial_val = (lower_bound_inside ? hay_lo : hay_hi) - 1;
+  }
+  if (sticks_out) return IntersectionType::kPartial;
+  return IntersectionType::kInside;
+}
+
+void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,  // Reorders samples [begin, end) by property `prop` until position `pos` falls inside (or on the edge of) a run of equal values.
+ size_t end, size_t prop) {  // Works by repeated three-way partitioning around a random pivot, narrowing to the side that contains `pos`.
+ auto cmp = [&](size_t a, size_t b) {  // Negative, zero or positive, like the sign of Property(a) - Property(b).
+ return int32_t(tree_samples.Property(prop, a)) -
+ int32_t(tree_samples.Property(prop, b));
+ };
+ Rng rng(0);  // Constant seed argument; presumably makes the pivot sequence reproducible.
+ while (end > begin + 1) {
+ {
+ size_t pivot = rng.UniformU(begin, end);  // Pick a pivot and move it to the front of the range.
+ tree_samples.Swap(begin, pivot);
+ }
+ size_t pivot_begin = begin;  // [pivot_begin, pivot_end) holds the samples equal to the pivot found so far.
+ size_t pivot_end = pivot_begin + 1;
+ for (size_t i = begin + 1; i < end; i++) {
+ JXL_DASSERT(i >= pivot_end);
+ JXL_DASSERT(pivot_end > pivot_begin);
+ int32_t cmp_result = cmp(i, pivot_begin);
+ if (cmp_result < 0) { // i < pivot, move pivot forward and put i before
+ // the pivot.
+ tree_samples.ThreeShuffle(pivot_begin, pivot_end, i);
+ pivot_begin++;
+ pivot_end++;
+ } else if (cmp_result == 0) {  // Equal to the pivot: grow the run of equal samples.
+ tree_samples.Swap(pivot_end, i);
+ pivot_end++;
+ }
+ }
+ JXL_DASSERT(pivot_begin >= begin);
+ JXL_DASSERT(pivot_end > pivot_begin);
+ JXL_DASSERT(pivot_end <= end);
+ for (size_t i = begin; i < pivot_begin; i++) {  // Debug-only verification of the three partitions.
+ JXL_DASSERT(cmp(i, pivot_begin) < 0);
+ }
+ for (size_t i = pivot_end; i < end; i++) {
+ JXL_DASSERT(cmp(i, pivot_begin) > 0);
+ }
+ for (size_t i = pivot_begin; i < pivot_end; i++) {
+ JXL_DASSERT(cmp(i, pivot_begin) == 0);
+ }
+ // We now have that [begin, pivot_begin) is < pivot, [pivot_begin,
+ // pivot_end) is = pivot, and [pivot_end, end) is > pivot.
+ // If pos falls in the first or the last interval, we continue in that
+ // interval; otherwise, we are done.
+ if (pivot_begin > pos) {
+ end = pivot_begin;
+ } else if (pivot_end < pos) {
+ begin = pivot_end;
+ } else {
+ break;
+ }
+ }
+}
+
+void FindBestSplit(TreeSamples &tree_samples, float threshold,
+ const std::vector<ModularMultiplierInfo> &mul_info,
+ StaticPropRange initial_static_prop_range,
+ float fast_decode_multiplier, Tree *tree) {
+ struct NodeInfo {
+ size_t pos;
+ size_t begin;
+ size_t end;
+ uint64_t used_properties;
+ StaticPropRange static_prop_range;
+ };
+ std::vector<NodeInfo> nodes;
+ nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
+ initial_static_prop_range});
+
+ size_t num_predictors = tree_samples.NumPredictors();
+ size_t num_properties = tree_samples.NumProperties();
+
+ // TODO(veluca): consider parallelizing the search (processing multiple nodes
+ // at a time).
+ while (!nodes.empty()) {
+ size_t pos = nodes.back().pos;
+ size_t begin = nodes.back().begin;
+ size_t end = nodes.back().end;
+ uint64_t used_properties = nodes.back().used_properties;
+ StaticPropRange static_prop_range = nodes.back().static_prop_range;
+ nodes.pop_back();
+ if (begin == end) continue;
+
+ struct SplitInfo {
+ size_t prop = 0;
+ uint32_t val = 0;
+ size_t pos = 0;
+ float lcost = std::numeric_limits<float>::max();
+ float rcost = std::numeric_limits<float>::max();
+ Predictor lpred = Predictor::Zero;
+ Predictor rpred = Predictor::Zero;
+ float Cost() { return lcost + rcost; }
+ };
+
+ SplitInfo best_split_static_constant;
+ SplitInfo best_split_static;
+ SplitInfo best_split_nonstatic;
+ SplitInfo best_split_nowp;
+
+ JXL_DASSERT(begin <= end);
+ JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
+
+ // Compute the maximum token in the range.
+ size_t max_symbols = 0;
+ for (size_t pred = 0; pred < num_predictors; pred++) {
+ for (size_t i = begin; i < end; i++) {
+ uint32_t tok = tree_samples.Token(pred, i);
+ max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
+ }
+ }
+ max_symbols = Padded(max_symbols);
+ std::vector<int32_t> rounded_counts(max_symbols);
+ std::vector<int32_t> counts(max_symbols * num_predictors);
+ std::vector<uint32_t> tot_extra_bits(num_predictors);
+ for (size_t pred = 0; pred < num_predictors; pred++) {
+ for (size_t i = begin; i < end; i++) {
+ counts[pred * max_symbols + tree_samples.Token(pred, i)] +=
+ tree_samples.Count(i);
+ tot_extra_bits[pred] +=
+ tree_samples.NBits(pred, i) * tree_samples.Count(i);
+ }
+ }
+
+ float base_bits;
+ {
+ size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
+ base_bits = EstimateBits(counts.data() + pred * max_symbols,
+ rounded_counts.data(), max_symbols) +
+ tot_extra_bits[pred];
+ }
+
+ SplitInfo *best = &best_split_nonstatic;
+
+ SplitInfo forced_split;
+ // The multiplier ranges cut halfway through the current ranges of static
+ // properties. We do this even if the current node is not a leaf, to
+ // minimize the number of nodes in the resulting tree.
+ for (size_t i = 0; i < mul_info.size(); i++) {
+ uint32_t axis, val;
+ IntersectionType t =
+ BoxIntersects(static_prop_range, mul_info[i].range, axis, val);
+ if (t == IntersectionType::kNone) continue;
+ if (t == IntersectionType::kInside) {
+ (*tree)[pos].multiplier = mul_info[i].multiplier;
+ break;
+ }
+ if (t == IntersectionType::kPartial) {
+ forced_split.val = tree_samples.QuantizeProperty(axis, val);
+ forced_split.prop = axis;
+ forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
+ forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
+ best = &forced_split;
+ best->pos = begin;
+ JXL_ASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
+ for (size_t x = begin; x < end; x++) {
+ if (tree_samples.Property(best->prop, x) <= best->val) {
+ best->pos++;
+ }
+ }
+ break;
+ }
+ }
+
+ if (best != &forced_split) {
+ std::vector<int> prop_value_used_count;
+ std::vector<int> count_increase;
+ std::vector<size_t> extra_bits_increase;
+ // For each property, compute which of its values are used, and what
+ // tokens correspond to those usages. Then, iterate through the values,
+ // and compute the entropy of each side of the split (of the form `prop >
+ // threshold`). Finally, find the split that minimizes the cost.
+ struct CostInfo {
+ float cost = std::numeric_limits<float>::max();
+ float extra_cost = 0;
+ float Cost() const { return cost + extra_cost; }
+ Predictor pred; // will be uninitialized in some cases, but never used.
+ };
+ std::vector<CostInfo> costs_l;
+ std::vector<CostInfo> costs_r;
+
+ std::vector<int32_t> counts_above(max_symbols);
+ std::vector<int32_t> counts_below(max_symbols);
+
+ // The lower the threshold, the higher the expected noisiness of the
+ // estimate. Thus, discourage changing predictors.
+ float change_pred_penalty = 800.0f / (100.0f + threshold);
+ for (size_t prop = 0; prop < num_properties && base_bits > threshold;
+ prop++) {
+ costs_l.clear();
+ costs_r.clear();
+ size_t prop_size = tree_samples.NumPropertyValues(prop);
+ if (extra_bits_increase.size() < prop_size) {
+ count_increase.resize(prop_size * max_symbols);
+ extra_bits_increase.resize(prop_size);
+ }
+ // Clear prop_value_used_count (which cannot be cleared "on the go")
+ prop_value_used_count.clear();
+ prop_value_used_count.resize(prop_size);
+
+ size_t first_used = prop_size;
+ size_t last_used = 0;
+
+ // TODO(veluca): consider finding multiple splits along a single
+ // property at the same time, possibly with a bottom-up approach.
+ for (size_t i = begin; i < end; i++) {
+ size_t p = tree_samples.Property(prop, i);
+ prop_value_used_count[p]++;
+ last_used = std::max(last_used, p);
+ first_used = std::min(first_used, p);
+ }
+ costs_l.resize(last_used - first_used);
+ costs_r.resize(last_used - first_used);
+ // For all predictors, compute the right and left costs of each split.
+ for (size_t pred = 0; pred < num_predictors; pred++) {
+ // Compute cost and histogram increments for each property value.
+ for (size_t i = begin; i < end; i++) {
+ size_t p = tree_samples.Property(prop, i);
+ size_t cnt = tree_samples.Count(i);
+ size_t sym = tree_samples.Token(pred, i);
+ count_increase[p * max_symbols + sym] += cnt;
+ extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt;
+ }
+ memcpy(counts_above.data(), counts.data() + pred * max_symbols,
+ max_symbols * sizeof counts_above[0]);
+ memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
+ size_t extra_bits_below = 0;
+ // Exclude last used: this ensures neither counts_above nor
+ // counts_below is empty.
+ for (size_t i = first_used; i < last_used; i++) {
+ if (!prop_value_used_count[i]) continue;
+ extra_bits_below += extra_bits_increase[i];
+ // The increase for this property value has been used, and will not
+ // be used again: clear it. Also below.
+ extra_bits_increase[i] = 0;
+ for (size_t sym = 0; sym < max_symbols; sym++) {
+ counts_above[sym] -= count_increase[i * max_symbols + sym];
+ counts_below[sym] += count_increase[i * max_symbols + sym];
+ count_increase[i * max_symbols + sym] = 0;
+ }
+ float rcost = EstimateBits(counts_above.data(),
+ rounded_counts.data(), max_symbols) +
+ tot_extra_bits[pred] - extra_bits_below;
+ float lcost = EstimateBits(counts_below.data(),
+ rounded_counts.data(), max_symbols) +
+ extra_bits_below;
+ JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
+ float penalty = 0;
+ // Never discourage moving away from the Weighted predictor.
+ if (tree_samples.PredictorFromIndex(pred) !=
+ (*tree)[pos].predictor &&
+ (*tree)[pos].predictor != Predictor::Weighted) {
+ penalty = change_pred_penalty;
+ }
+ // If everything else is equal, disfavour Weighted (slower) and
+ // favour Zero (faster if it's the only predictor used in a
+ // group+channel combination)
+ if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
+ penalty += 1e-8;
+ }
+ if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
+ penalty -= 1e-8;
+ }
+ if (rcost + penalty < costs_r[i - first_used].Cost()) {
+ costs_r[i - first_used].cost = rcost;
+ costs_r[i - first_used].extra_cost = penalty;
+ costs_r[i - first_used].pred =
+ tree_samples.PredictorFromIndex(pred);
+ }
+ if (lcost + penalty < costs_l[i - first_used].Cost()) {
+ costs_l[i - first_used].cost = lcost;
+ costs_l[i - first_used].extra_cost = penalty;
+ costs_l[i - first_used].pred =
+ tree_samples.PredictorFromIndex(pred);
+ }
+ }
+ }
+ // Iterate through the possible splits and find the one with minimum sum
+ // of costs of the two sides.
+ size_t split = begin;
+ for (size_t i = first_used; i < last_used; i++) {
+ if (!prop_value_used_count[i]) continue;
+ split += prop_value_used_count[i];
+ float rcost = costs_r[i - first_used].cost;
+ float lcost = costs_l[i - first_used].cost;
+ // WP was not used + we would use the WP property or predictor
+ bool adds_wp =
+ (tree_samples.PropertyFromIndex(prop) == kWPProp &&
+ (used_properties & (1LU << prop)) == 0) ||
+ ((costs_l[i - first_used].pred == Predictor::Weighted ||
+ costs_r[i - first_used].pred == Predictor::Weighted) &&
+ (*tree)[pos].predictor != Predictor::Weighted);
+ bool zero_entropy_side = rcost == 0 || lcost == 0;
+
+ SplitInfo &best =
+ prop < kNumStaticProperties
+ ? (zero_entropy_side ? best_split_static_constant
+ : best_split_static)
+ : (adds_wp ? best_split_nonstatic : best_split_nowp);
+ if (lcost + rcost < best.Cost()) {
+ best.prop = prop;
+ best.val = i;
+ best.pos = split;
+ best.lcost = lcost;
+ best.lpred = costs_l[i - first_used].pred;
+ best.rcost = rcost;
+ best.rpred = costs_r[i - first_used].pred;
+ }
+ }
+ // Clear extra_bits_increase and count_increase for last_used.
+ extra_bits_increase[last_used] = 0;
+ for (size_t sym = 0; sym < max_symbols; sym++) {
+ count_increase[last_used * max_symbols + sym] = 0;
+ }
+ }
+
+ // Try to avoid introducing WP.
+ if (best_split_nowp.Cost() + threshold < base_bits &&
+ best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
+ best = &best_split_nowp;
+ }
+ // Split along static props if possible and not significantly more
+ // expensive.
+ if (best_split_static.Cost() + threshold < base_bits &&
+ best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
+ best = &best_split_static;
+ }
+ // Split along static props to create constant nodes if possible.
+ if (best_split_static_constant.Cost() + threshold < base_bits) {
+ best = &best_split_static_constant;
+ }
+ }
+
+ if (best->Cost() + threshold < base_bits) {
+ uint32_t p = tree_samples.PropertyFromIndex(best->prop);
+ pixel_type dequant =
+ tree_samples.UnquantizeProperty(best->prop, best->val);
+ // Split node and try to split children.
+ MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
+ // "Sort" according to winning property
+ SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop);
+ if (p >= kNumStaticProperties) {
+ used_properties |= 1 << best->prop;
+ }
+ auto new_sp_range = static_prop_range;
+ if (p < kNumStaticProperties) {
+ JXL_ASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
+ new_sp_range[p][1] = dequant + 1;
+ JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
+ }
+ nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
+ used_properties, new_sp_range});
+ new_sp_range = static_prop_range;
+ if (p < kNumStaticProperties) {
+ JXL_ASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
+ new_sp_range[p][0] = dequant + 1;
+ JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
+ }
+ nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
+ used_properties, new_sp_range});
+ }
+ }
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+HWY_EXPORT(FindBestSplit); // Local function.
+
+// Builds the decision tree for `tree_samples` into `*tree`.
+//
+// A single leaf is appended to `*tree` to act as the root (property -1, the
+// first configured predictor, no offset, multiplier 1); the remaining
+// arguments are forwarded unchanged to FindBestSplit through
+// HWY_DYNAMIC_DISPATCH, which grows the tree from that root.
+void ComputeBestTree(TreeSamples &tree_samples, float threshold,
+                     const std::vector<ModularMultiplierInfo> &mul_info,
+                     StaticPropRange static_prop_range,
+                     float fast_decode_multiplier, Tree *tree) {
+  // TODO(veluca): take into account that different contexts can have different
+  // uint configs.
+
+  // Seed the tree with its root leaf.
+  tree->emplace_back();
+  auto &root = tree->back();
+  root.property = -1;
+  root.predictor = tree_samples.PredictorFromIndex(0);
+  root.predictor_offset = 0;
+  root.multiplier = 1;
+
+  // NOTE(review): the property limit presumably exists because the splitter
+  // tracks used properties in an integer bitmask -- confirm.
+  JXL_ASSERT(tree_samples.NumProperties() < 64);
+
+  JXL_ASSERT(tree_samples.NumDistinctSamples() <=
+             std::numeric_limits<uint32_t>::max());
+
+  HWY_DYNAMIC_DISPATCH(FindBestSplit)
+  (tree_samples, threshold, mul_info, static_prop_range,
+   fast_decode_multiplier, tree);
+}
+
+// Out-of-class definitions for the static constexpr members declared (with
+// their initializers) inside TreeSamples. Before C++17 such a definition is
+// required whenever the member is odr-used, e.g. bound to a const reference.
+constexpr int32_t TreeSamples::kPropertyRange;
+constexpr uint32_t TreeSamples::kDedupEntryUnused;
+
+// Chooses which predictors residuals will be collected for, and sizes
+// `residuals` to hold one residual list per chosen predictor.
+//
+// - kWPOnly: only Predictor::Weighted is used, whatever `predictor` says.
+// - kNoWP: asking for Predictor::Weighted is an error; otherwise Weighted is
+//   removed from the resulting list if present.
+// - Predictor::Variable: every predictor id below kNumModularPredictors is
+//   appended, then Weighted and Gradient are moved to slots 0 and 1.
+// - Predictor::Best: Weighted and Gradient.
+// - anything else: just `predictor`.
+//
+// NOTE(review): the Variable branch appends to `predictors` without clearing
+// it first, so it relies on the list being empty on entry.
+Status TreeSamples::SetPredictor(Predictor predictor,
+                                 ModularOptions::TreeMode wp_tree_mode) {
+  using TreeMode = ModularOptions::TreeMode;
+
+  if (wp_tree_mode == TreeMode::kWPOnly) {
+    predictors = {Predictor::Weighted};
+    residuals.resize(1);
+    return true;
+  }
+
+  const bool forbid_wp = (wp_tree_mode == TreeMode::kNoWP);
+  if (forbid_wp && predictor == Predictor::Weighted) {
+    return JXL_FAILURE("Invalid predictor settings");
+  }
+
+  switch (predictor) {
+    case Predictor::Variable: {
+      for (size_t id = 0; id < kNumModularPredictors; id++) {
+        predictors.push_back(static_cast<Predictor>(id));
+      }
+      const int weighted_slot = static_cast<int>(Predictor::Weighted);
+      const int gradient_slot = static_cast<int>(Predictor::Gradient);
+      std::swap(predictors[0], predictors[weighted_slot]);
+      std::swap(predictors[1], predictors[gradient_slot]);
+      break;
+    }
+    case Predictor::Best:
+      predictors = {Predictor::Weighted, Predictor::Gradient};
+      break;
+    default:
+      predictors = {predictor};
+      break;
+  }
+
+  if (forbid_wp) {
+    // Drop the (first) Weighted entry, if any.
+    const auto weighted =
+        std::find(predictors.begin(), predictors.end(), Predictor::Weighted);
+    if (weighted != predictors.end()) predictors.erase(weighted);
+  }
+
+  residuals.resize(predictors.size());
+  return true;
+}
+
+// Chooses which properties samples will record, and sizes `props` to hold one
+// value list per chosen property.
+//
+// kWPOnly and kGradientOnly replace the requested list with the single
+// property kWPProp / kGradientProp respectively; kNoWP keeps the requested
+// list minus its first kWPProp entry; every other mode uses `properties`
+// as-is. Fails if the resulting list is empty.
+Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
+                                  ModularOptions::TreeMode wp_tree_mode) {
+  switch (wp_tree_mode) {
+    case ModularOptions::TreeMode::kWPOnly:
+      props_to_use = {static_cast<uint32_t>(kWPProp)};
+      break;
+    case ModularOptions::TreeMode::kGradientOnly:
+      props_to_use = {static_cast<uint32_t>(kGradientProp)};
+      break;
+    case ModularOptions::TreeMode::kNoWP: {
+      props_to_use = properties;
+      const auto wp_pos =
+          std::find(props_to_use.begin(), props_to_use.end(), kWPProp);
+      if (wp_pos != props_to_use.end()) props_to_use.erase(wp_pos);
+      break;
+    }
+    default:
+      props_to_use = properties;
+      break;
+  }
+
+  if (props_to_use.empty()) {
+    return JXL_FAILURE("Invalid property set configuration");
+  }
+  props.resize(props_to_use.size());
+  return true;
+}
+
+// Resizes the deduplication table to `size` slots (a power of two) and
+// registers every stored sample whose count is not saturated. A table that
+// already has the requested size is left untouched.
+//
+// Note that std::vector::resize keeps whatever entries are already stored in
+// the retained part of the table; only newly created slots start out as
+// kDedupEntryUnused.
+void TreeSamples::InitTable(size_t size) {
+  JXL_DASSERT((size & (size - 1)) == 0);
+  if (dedup_table_.size() != size) {
+    dedup_table_.resize(size, kDedupEntryUnused);
+    const uint16_t saturated = std::numeric_limits<uint16_t>::max();
+    const size_t distinct = NumDistinctSamples();
+    for (size_t idx = 0; idx < distinct; idx++) {
+      // Saturated samples can no longer absorb duplicates: keep them out.
+      if (sample_counts[idx] == saturated) continue;
+      AddToTable(idx);
+    }
+  }
+}
+
+// Tries to fold sample `a` into an identical sample already registered in the
+// deduplication table.
+//
+// Both candidate slots (Hash1, Hash2) are probed in that order. On a match
+// the existing sample's count is incremented and true is returned (the caller
+// is then expected to discard `a`). An existing sample whose count reaches the
+// uint16_t maximum is removed from the table so that it is not incremented
+// further. Without a match, `a` is registered via AddToTable and false is
+// returned.
+bool TreeSamples::AddToTableAndMerge(size_t a) {
+  const size_t pos1 = Hash1(a);
+  const size_t pos2 = Hash2(a);
+
+  // Attempts the merge through one table slot; returns whether it happened.
+  const auto merge_via = [&](size_t slot) -> bool {
+    const uint32_t existing = dedup_table_[slot];
+    if (existing == kDedupEntryUnused || !IsSameSample(a, existing)) {
+      return false;
+    }
+    JXL_DASSERT(sample_counts[a] == 1);
+    sample_counts[existing]++;
+    // Remove from hash table samples that are saturated.
+    if (sample_counts[existing] == std::numeric_limits<uint16_t>::max()) {
+      dedup_table_[slot] = kDedupEntryUnused;
+    }
+    return true;
+  };
+
+  if (merge_via(pos1) || merge_via(pos2)) return true;
+  AddToTable(a);
+  return false;
+}
+
+// Registers sample `a` in the first free slot among its two candidate slots
+// (Hash1 first, then Hash2). If both are occupied the sample is simply not
+// registered; nothing is evicted.
+void TreeSamples::AddToTable(size_t a) {
+  const size_t candidates[2] = {Hash1(a), Hash2(a)};
+  for (size_t slot : candidates) {
+    if (dedup_table_[slot] == kDedupEntryUnused) {
+      dedup_table_[slot] = a;
+      return;
+    }
+  }
+}
+
+// Reserves room for `num_samples` further samples in every residual and
+// property list, and sizes the deduplication table to a power of two derived
+// from 1.5x the number of samples expected in total (already stored distinct
+// samples plus `num_samples`).
+//
+// NOTE(review): CeilLog2Nonzero is given 0 when no samples are stored and
+// `num_samples` is 0; its name suggests that input is not allowed -- confirm
+// callers never do this.
+void TreeSamples::PrepareForSamples(size_t num_samples) {
+  for (size_t i = 0; i < residuals.size(); i++) {
+    residuals[i].reserve(residuals[i].size() + num_samples);
+  }
+  for (size_t i = 0; i < props.size(); i++) {
+    props[i].reserve(props[i].size() + num_samples);
+  }
+  const size_t expected_total = sample_counts.size() + num_samples;
+  const size_t table_size = 1LLU << CeilLog2Nonzero(expected_total * 3 / 2);
+  InitTable(table_size);
+}
+
+// First hash of sample `a`: a multiply-and-add mix over the residual tokens
+// and extra-bit counts of every predictor, followed by every property value.
+// The result is reduced to a slot index with a mask, which relies on the
+// table size being a power of two.
+size_t TreeSamples::Hash1(size_t a) const {
+  constexpr uint64_t kMul = 0x1e35a7bd;
+  uint64_t acc = kMul;
+  for (size_t i = 0; i < residuals.size(); i++) {
+    const ResidualToken &res = residuals[i][a];
+    acc = acc * kMul + res.tok;
+    acc = acc * kMul + res.nbits;
+  }
+  for (size_t i = 0; i < props.size(); i++) {
+    acc = acc * kMul + props[i][a];
+  }
+  const size_t mask = dedup_table_.size() - 1;
+  return (acc >> 16) & mask;
+}
+// Second hash of sample `a`: a multiply-and-xor mix that visits the property
+// values first and the residual tokens / extra-bit counts afterwards (the
+// opposite order of Hash1, with a different constant). The result is reduced
+// to a slot index with a mask, which relies on the table size being a power
+// of two.
+size_t TreeSamples::Hash2(size_t a) const {
+  constexpr uint64_t kMul = 0x1e35a7bd1e35a7bd;
+  uint64_t acc = kMul;
+  for (size_t i = 0; i < props.size(); i++) {
+    acc = (acc * kMul) ^ props[i][a];
+  }
+  for (size_t i = 0; i < residuals.size(); i++) {
+    const ResidualToken &res = residuals[i][a];
+    acc = (acc * kMul) ^ res.tok;
+    acc = (acc * kMul) ^ res.nbits;
+  }
+  const size_t mask = dedup_table_.size() - 1;
+  return (acc >> 16) & mask;
+}
+
+// Returns true iff samples `a` and `b` agree on the residual token and
+// extra-bit count of every predictor and on every property value. Sample
+// counts are not compared.
+bool TreeSamples::IsSameSample(size_t a, size_t b) const {
+  for (size_t i = 0; i < residuals.size(); i++) {
+    const ResidualToken &lhs = residuals[i][a];
+    const ResidualToken &rhs = residuals[i][b];
+    if (lhs.tok != rhs.tok || lhs.nbits != rhs.nbits) return false;
+  }
+  for (size_t i = 0; i < props.size(); i++) {
+    if (props[i][a] != props[i][b]) return false;
+  }
+  return true;
+}
+
+// Records one sample.
+//
+// For each configured predictor the residual `pixel - prediction` is encoded
+// with HybridUintConfig(4, 1, 2) and its token and extra-bit count are stored
+// (both are expected to fit in 8 bits). Each configured property is stored in
+// quantized form. The sample is appended with a count of 1; if an identical
+// sample is already registered in the deduplication table, that sample's
+// count is increased instead and the freshly appended entry is removed again.
+// `num_samples` is incremented either way.
+void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
+                            const pixel_type_w *predictions) {
+  for (size_t pred = 0; pred < predictors.size(); pred++) {
+    const pixel_type_w predicted =
+        predictions[static_cast<int>(predictors[pred])];
+    pixel_type v = pixel - predicted;
+    uint32_t tok, nbits, bits;
+    HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits);
+    JXL_DASSERT(tok < 256);
+    JXL_DASSERT(nbits < 256);
+    residuals[pred].push_back(
+        ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
+  }
+  for (size_t p = 0; p < props_to_use.size(); p++) {
+    const uint32_t property_id = props_to_use[p];
+    props[p].push_back(QuantizeProperty(p, properties[property_id]));
+  }
+  sample_counts.push_back(1);
+  num_samples++;
+
+  const size_t appended = sample_counts.size() - 1;
+  if (!AddToTableAndMerge(appended)) return;
+
+  // The sample was merged into an existing one: undo the append.
+  for (auto &r : residuals) r.pop_back();
+  for (auto &p : props) p.pop_back();
+  sample_counts.pop_back();
+}
+
+// Exchanges samples `a` and `b` in every per-sample array (residuals,
+// properties and counts). Does nothing when both indices are the same.
+void TreeSamples::Swap(size_t a, size_t b) {
+  if (a != b) {
+    for (size_t i = 0; i < residuals.size(); i++) {
+      std::swap(residuals[i][a], residuals[i][b]);
+    }
+    for (size_t i = 0; i < props.size(); i++) {
+      std::swap(props[i][a], props[i][b]);
+    }
+    std::swap(sample_counts[a], sample_counts[b]);
+  }
+}
+
+// Rotates three samples in every per-sample array (residuals, properties and
+// counts): the sample at `a` ends up at `b`, the one at `b` at `c`, and the
+// one at `c` at `a`.
+//
+// The assignments below are order-sensitive when indices coincide, so they
+// should not be replaced by a sequence of swaps without re-checking those
+// cases: b == c is handled explicitly as a swap, while with a == b (and
+// b != c) the sequence leaves all three positions unchanged.
+void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) {
+ // With b == c the rotation degenerates into exchanging a and b.
+ if (b == c) return Swap(a, b);
+ for (auto &r : residuals) {
+ auto tmp = r[a];
+ r[a] = r[c];
+ r[c] = r[b];
+ r[b] = tmp;
+ }
+ for (auto &p : props) {
+ auto tmp = p[a];
+ p[a] = p[c];
+ p[c] = p[b];
+ p[b] = tmp;
+ }
+ auto tmp = sample_counts[a];
+ sample_counts[a] = sample_counts[c];
+ sample_counts[c] = sample_counts[b];
+ sample_counts[b] = tmp;
+}
+
+namespace {
+// Picks split thresholds that cut `histogram` into roughly `num_chunks`
+// parts of similar total weight.
+//
+// `histogram[i]` is the weight of bucket i. A bucket index i is emitted as a
+// threshold whenever the running total up to and including bucket i exceeds
+// the next pending quantile boundary; quantile boundaries already passed by
+// the running total are then skipped, so one heavy bucket yields at most one
+// threshold. The last bucket is never emitted. Returned thresholds are
+// strictly increasing bucket indices.
+//
+// Returns an empty list for an empty histogram or when `num_chunks` is 0
+// (which would otherwise divide by zero).
+std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
+                                       size_t num_chunks) {
+  if (histogram.empty() || num_chunks == 0) return {};
+  // TODO(veluca): selecting distinct quantiles is likely not the best
+  // way to go about this.
+  std::vector<int32_t> thresholds;
+  // Accumulate in size_t: the type of the initial value is the type
+  // std::accumulate sums in, and `unsigned long` is only 32 bits wide on
+  // LLP64 targets, where large pixel counts could wrap around.
+  const size_t sum =
+      std::accumulate(histogram.begin(), histogram.end(), size_t{0});
+  size_t cumsum = 0;
+  size_t threshold = 0;
+  for (size_t i = 0; i + 1 < histogram.size(); i++) {
+    cumsum += histogram[i];
+    if (cumsum > (threshold + 1) * sum / num_chunks) {
+      thresholds.push_back(i);
+      // Skip every quantile boundary the running total has already reached.
+      while (cumsum >= (threshold + 1) * sum / num_chunks) threshold++;
+    }
+  }
+  return thresholds;
+}
+
+// Picks split thresholds for a list of sample values by histogramming them
+// and delegating to QuantizeHistogram.
+//
+// Samples (and the minimum used as histogram origin) are clamped to
+// [-kRange, kRange] first, so every bucket offset lies in [0, 2 * kRange].
+// The returned thresholds are sample values, not bucket indices. Returns an
+// empty list when there are no samples.
+std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
+                                     size_t num_chunks) {
+  if (samples.empty()) return {};
+  int min = *std::min_element(samples.begin(), samples.end());
+  constexpr int kRange = 512;
+  min = std::min(std::max(min, -kRange), kRange);
+  std::vector<uint32_t> counts(2 * kRange + 1);
+  for (int s : samples) {
+    uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min;
+    counts[sample_offset]++;
+  }
+  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
+  // Convert bucket indices back to sample values.
+  for (auto &v : thresholds) v += min;
+  return thresholds;
+}
+}  // namespace
+
+// Computes, for every property in `props_to_use`, the list of split
+// thresholds (`compact_properties`) and the lookup table that maps a clamped
+// property value to the index of its threshold bucket (`property_mapping`).
+//
+// `range` and `multiplier_info` are only used for properties 0 and 1:
+// wherever a multiplier range bound differs from the overall `range`, a
+// threshold is forced at that bound. `group_pixel_count` and
+// `channel_pixel_count` are histograms used when no such forced thresholds
+// exist. `pixel_samples` and `diff_samples` are taken by non-const reference
+// because the "abs" quantizers below replace their contents with absolute
+// values. `max_property_values` bounds the number of buckets requested from
+// the histogram-based quantizers.
+void TreeSamples::PreQuantizeProperties(
+ const StaticPropRange &range,
+ const std::vector<ModularMultiplierInfo> &multiplier_info,
+ const std::vector<uint32_t> &group_pixel_count,
+ const std::vector<uint32_t> &channel_pixel_count,
+ std::vector<pixel_type> &pixel_samples,
+ std::vector<pixel_type> &diff_samples, size_t max_property_values) {
+ // If we have forced splits because of multipliers, choose channel and group
+ // thresholds accordingly.
+ std::vector<int32_t> group_multiplier_thresholds;
+ std::vector<int32_t> channel_multiplier_thresholds;
+ // Index 0 of a range feeds the channel thresholds and index 1 the group
+ // thresholds. Each bound that differs from the overall range becomes a
+ // threshold at `bound - 1`.
+ for (const auto &v : multiplier_info) {
+ if (v.range[0][0] != range[0][0]) {
+ channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
+ }
+ if (v.range[0][1] != range[0][1]) {
+ channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
+ }
+ if (v.range[1][0] != range[1][0]) {
+ group_multiplier_thresholds.push_back(v.range[1][0] - 1);
+ }
+ if (v.range[1][1] != range[1][1]) {
+ group_multiplier_thresholds.push_back(v.range[1][1] - 1);
+ }
+ }
+ // Sort and deduplicate both threshold lists.
+ std::sort(channel_multiplier_thresholds.begin(),
+ channel_multiplier_thresholds.end());
+ channel_multiplier_thresholds.resize(
+ std::unique(channel_multiplier_thresholds.begin(),
+ channel_multiplier_thresholds.end()) -
+ channel_multiplier_thresholds.begin());
+ std::sort(group_multiplier_thresholds.begin(),
+ group_multiplier_thresholds.end());
+ group_multiplier_thresholds.resize(
+ std::unique(group_multiplier_thresholds.begin(),
+ group_multiplier_thresholds.end()) -
+ group_multiplier_thresholds.begin());
+
+ compact_properties.resize(props_to_use.size());
+ // Each quantize_* lambda below returns (by value) the threshold list for one
+ // family of properties.
+ auto quantize_channel = [&]() {
+ if (!channel_multiplier_thresholds.empty()) {
+ return channel_multiplier_thresholds;
+ }
+ return QuantizeHistogram(channel_pixel_count, max_property_values);
+ };
+ auto quantize_group_id = [&]() {
+ if (!group_multiplier_thresholds.empty()) {
+ return group_multiplier_thresholds;
+ }
+ return QuantizeHistogram(group_pixel_count, max_property_values);
+ };
+ // Evenly spaced thresholds over [0, 256).
+ // NOTE(review): max_property_values == 0 would make the reserve() argument
+ // wrap around; callers presumably always pass a positive value -- confirm.
+ auto quantize_coordinate = [&]() {
+ std::vector<int32_t> quantized;
+ quantized.reserve(max_property_values - 1);
+ for (size_t i = 0; i + 1 < max_property_values; i++) {
+ quantized.push_back((i + 1) * 256 / max_property_values - 1);
+ }
+ return quantized;
+ };
+ // The four sample-based quantizers cache their result in the vectors
+ // declared next to them, using "empty" to mean "not computed yet"; a
+ // computed-but-empty result is therefore recomputed on the next call.
+ std::vector<int32_t> abs_pixel_thr;
+ std::vector<int32_t> pixel_thr;
+ auto quantize_pixel_property = [&]() {
+ if (pixel_thr.empty()) {
+ pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
+ }
+ return pixel_thr;
+ };
+ auto quantize_abs_pixel_property = [&]() {
+ if (abs_pixel_thr.empty()) {
+ quantize_pixel_property(); // Compute the non-abs thresholds.
+ // From here on pixel_samples holds absolute values.
+ for (auto &v : pixel_samples) v = std::abs(v);
+ abs_pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
+ }
+ return abs_pixel_thr;
+ };
+ std::vector<int32_t> abs_diff_thr;
+ std::vector<int32_t> diff_thr;
+ auto quantize_diff_property = [&]() {
+ if (diff_thr.empty()) {
+ diff_thr = QuantizeSamples(diff_samples, max_property_values);
+ }
+ return diff_thr;
+ };
+ auto quantize_abs_diff_property = [&]() {
+ if (abs_diff_thr.empty()) {
+ quantize_diff_property(); // Compute the non-abs thresholds.
+ // From here on diff_samples holds absolute values.
+ for (auto &v : diff_samples) v = std::abs(v);
+ abs_diff_thr = QuantizeSamples(diff_samples, max_property_values);
+ }
+ return abs_diff_thr;
+ };
+ // Fixed threshold tables for kWPProp; a larger max_property_values selects
+ // a finer table.
+ auto quantize_wp = [&]() {
+ if (max_property_values < 32) {
+ return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
+ 1, 3, 7, 15, 31, 63, 127};
+ }
+ if (max_property_values < 64) {
+ return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
+ -15, -11, -7, -5, -3, -1, 0, 1,
+ 3, 5, 7, 11, 15, 23, 31, 47,
+ 63, 95, 127, 191, 255};
+ }
+ return std::vector<int32_t>{
+ -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
+ -39, -31, -27, -23, -19, -15, -13, -11, -9, -7, -6,
+ -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5,
+ 6, 7, 9, 11, 13, 15, 19, 23, 27, 31, 39,
+ 47, 55, 63, 79, 95, 111, 127, 159, 191, 223, 255};
+ };
+
+ property_mapping.resize(props_to_use.size());
+ for (size_t i = 0; i < props_to_use.size(); i++) {
+ // Pick the quantizer from the property id. Ids at or above
+ // kNumNonrefProperties are dispatched on their offset modulo 4.
+ if (props_to_use[i] == 0) {
+ compact_properties[i] = quantize_channel();
+ } else if (props_to_use[i] == 1) {
+ compact_properties[i] = quantize_group_id();
+ } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
+ compact_properties[i] = quantize_coordinate();
+ } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
+ props_to_use[i] == 8 ||
+ (props_to_use[i] >= kNumNonrefProperties &&
+ (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
+ compact_properties[i] = quantize_pixel_property();
+ } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
+ (props_to_use[i] >= kNumNonrefProperties &&
+ (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
+ compact_properties[i] = quantize_abs_pixel_property();
+ } else if (props_to_use[i] >= kNumNonrefProperties &&
+ (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
+ compact_properties[i] = quantize_abs_diff_property();
+ } else if (props_to_use[i] == kWPProp) {
+ compact_properties[i] = quantize_wp();
+ } else {
+ compact_properties[i] = quantize_diff_property();
+ }
+ // Table index j stands for the property value j - kPropertyRange.
+ property_mapping[i].resize(kPropertyRange * 2 + 1);
+ size_t mapped = 0;
+ for (size_t j = 0; j < property_mapping[i].size(); j++) {
+ while (mapped < compact_properties[i].size() &&
+ static_cast<int>(j) - kPropertyRange >
+ compact_properties[i][mapped]) {
+ mapped++;
+ }
+ // With V = j - kPropertyRange, property_mapping[i][j] is `mapped` where
+ // compact_properties[i][mapped-1] < V and
+ // compact_properties[i][mapped] >= V
+ // (or mapped == compact_properties[i].size() if V is above every
+ // threshold).
+ // This is because the decision node in the tree splits on
+ // (property) > threshold, hence everything that is not > of a threshold
+ // should be clustered together.
+ property_mapping[i][j] = mapped;
+ }
+ }
+}
+
+// Gathers a random subset of pixel values (and horizontal neighbour
+// differences) from `image`, and adds the pixel counts of the visited
+// channels to the per-group and per-channel histograms.
+//
+// Does nothing when options.nb_repeats is 0. `group_pixel_count` and
+// `channel_pixel_count` are grown as needed; `pixel_samples` and
+// `diff_samples` are appended to.
+void CollectPixelSamples(const Image &image, const ModularOptions &options,
+ size_t group_id,
+ std::vector<uint32_t> &group_pixel_count,
+ std::vector<uint32_t> &channel_pixel_count,
+ std::vector<pixel_type> &pixel_samples,
+ std::vector<pixel_type> &diff_samples) {
+ if (options.nb_repeats == 0) return;
+ if (group_pixel_count.size() <= group_id) {
+ group_pixel_count.resize(group_id + 1);
+ }
+ if (channel_pixel_count.size() < image.channel.size()) {
+ channel_pixel_count.resize(image.channel.size());
+ }
+ // The generator is seeded with the group id.
+ Rng rng(group_id);
+ // Sample 10% of the final number of samples for property quantization.
+ float fraction = std::min(options.nb_repeats * 0.1, 0.99);
+ Rng::GeometricDistribution dist(fraction);
+ size_t total_pixels = 0;
+ // Channels that will be sampled, in image order.
+ std::vector<size_t> channel_ids;
+ for (size_t i = 0; i < image.channel.size(); i++) {
+ if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
+ continue; // skip empty or width-1 channels.
+ }
+ // Stop at the first non-meta channel that exceeds the size limit; later
+ // channels are not considered either.
+ if (i >= image.nb_meta_channels &&
+ (image.channel[i].w > options.max_chan_size ||
+ image.channel[i].h > options.max_chan_size)) {
+ break;
+ }
+ channel_ids.push_back(i);
+ group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
+ channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
+ total_pixels += image.channel[i].w * image.channel[i].h;
+ }
+ if (channel_ids.empty()) return;
+ pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
+ diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
+ // Sampling cursor: index into channel_ids, then row and column.
+ size_t i = 0;
+ size_t y = 0;
+ size_t x = 0;
+ // Moves the cursor `amount` pixels forward in raster order, wrapping to the
+ // next row and then to the next selected channel. Once the last channel is
+ // exhausted, i == channel_ids.size() and the cursor must not be used.
+ auto advance = [&](size_t amount) {
+ x += amount;
+ // Detect row overflow (rare).
+ while (x >= image.channel[channel_ids[i]].w) {
+ x -= image.channel[channel_ids[i]].w;
+ y++;
+ // Detect end-of-channel (even rarer).
+ if (y == image.channel[channel_ids[i]].h) {
+ i++;
+ y = 0;
+ if (i >= channel_ids.size()) {
+ return;
+ }
+ }
+ }
+ };
+ // Skip a random number of pixels before the first sample, and at least one
+ // pixel between consecutive samples.
+ advance(rng.Geometric(dist));
+ for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
+ const pixel_type *row = image.channel[channel_ids[i]].Row(y);
+ pixel_samples.push_back(row[x]);
+ // Difference with the left neighbour; in the first column the right
+ // neighbour is used instead (selected channels are at least 2 wide).
+ size_t xp = x == 0 ? 1 : x - 1;
+ // The subtraction is done in 64 bits and converted to pixel_type on
+ // insertion.
+ diff_samples.push_back((int64_t)row[x] - row[xp]);
+ }
+}
+
+// TODO(veluca): very simple encoding scheme. This should be improved.
+// Serializes `tree` into `tokens` and rebuilds it, in breadth-first order,
+// into `decoder_tree` (which is cleared first).
+//
+// Nodes are visited breadth-first starting from node 0. Every node emits its
+// property plus one (so leaves, whose property is -1, emit 0). A leaf then
+// emits its predictor, packed predictor offset, and its multiplier split into
+// the number of trailing zero bits and the remaining factor minus one. A
+// split node emits its packed split value. Leaves are numbered in visiting
+// order; that number is passed to the decoder node in the same constructor
+// slot that split nodes use for their left child index.
+void TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
+                  Tree *decoder_tree) {
+  JXL_ASSERT(tree.size() <= kMaxTreeSize);
+  std::queue<int> pending;
+  pending.push(0);
+  size_t leaf_id = 0;
+  decoder_tree->clear();
+  while (!pending.empty()) {
+    const int cur = pending.front();
+    pending.pop();
+    const auto &node = tree[cur];
+    JXL_ASSERT(tree[cur].property >= -1);
+    tokens->emplace_back(kPropertyContext, node.property + 1);
+
+    const bool is_leaf = (node.property == -1);
+    if (is_leaf) {
+      tokens->emplace_back(kPredictorContext,
+                           static_cast<int>(node.predictor));
+      tokens->emplace_back(kOffsetContext, PackSigned(node.predictor_offset));
+      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(node.multiplier);
+      uint32_t mul_bits = (node.multiplier >> mul_log) - 1;
+      tokens->emplace_back(kMultiplierLogContext, mul_log);
+      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
+      JXL_ASSERT(tree[cur].predictor < Predictor::Best);
+      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, node.predictor,
+                                 node.predictor_offset, node.multiplier);
+    } else {
+      // This node lands at index decoder_tree->size(); the nodes still
+      // pending come right after it, and its two children after those.
+      const size_t lchild_index = decoder_tree->size() + pending.size() + 1;
+      const size_t rchild_index = decoder_tree->size() + pending.size() + 2;
+      decoder_tree->emplace_back(node.property, node.splitval, lchild_index,
+                                 rchild_index, Predictor::Zero, 0, 1);
+      pending.push(node.lchild);
+      pending.push(node.rchild);
+      tokens->emplace_back(kSplitValContext, PackSigned(node.splitval));
+    }
+  }
+}
+
+} // namespace jxl
+#endif // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h
new file mode 100644
index 0000000000..ede37c8023
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h
@@ -0,0 +1,157 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_ENCODING_ENC_MA_H_
+#define LIB_JXL_MODULAR_ENCODING_ENC_MA_H_
+
+#include <numeric>
+
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/modular/encoding/dec_ma.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/options.h"
+
+namespace jxl {
+
+// Struct to collect all the data needed to build a tree.
+// Identical samples (same residual token / extra-bit count for every
+// predictor and same quantized value for every property) are stored once,
+// together with an occurrence count. Sample indices `i` taken by the
+// accessors below refer to these distinct samples.
+struct TreeSamples {
+ // True once at least one predictor is configured and at least one sample is
+ // stored for it.
+ bool HasSamples() const {
+ return !residuals.empty() && !residuals[0].empty();
+ }
+ // Number of stored (deduplicated) samples.
+ size_t NumDistinctSamples() const { return sample_counts.size(); }
+ // Number of AddSample calls so far, duplicates included.
+ size_t NumSamples() const { return num_samples; }
+ // Set the predictor to use. Must be called before adding any samples.
+ Status SetPredictor(Predictor predictor,
+ ModularOptions::TreeMode wp_tree_mode);
+ // Set the properties to use. Must be called before adding any samples.
+ Status SetProperties(const std::vector<uint32_t> &properties,
+ ModularOptions::TreeMode wp_tree_mode);
+
+ // Residual token of sample `i` under the predictor with index `pred`.
+ size_t Token(size_t pred, size_t i) const { return residuals[pred][i].tok; }
+ // Number of extra bits of sample `i` under the predictor with index `pred`.
+ size_t NBits(size_t pred, size_t i) const { return residuals[pred][i].nbits; }
+ // Occurrence count of sample `i` (stored as uint16_t).
+ size_t Count(size_t i) const { return sample_counts[i]; }
+ // Position of `predictor` in the list of predictors in use.
+ size_t PredictorIndex(Predictor predictor) const {
+ const auto predictor_elem =
+ std::find(predictors.begin(), predictors.end(), predictor);
+ JXL_DASSERT(predictor_elem != predictors.end());
+ return predictor_elem - predictors.begin();
+ }
+ // Position of `property` in the list of properties in use.
+ size_t PropertyIndex(size_t property) const {
+ const auto property_elem =
+ std::find(props_to_use.begin(), props_to_use.end(), property);
+ JXL_DASSERT(property_elem != props_to_use.end());
+ return property_elem - props_to_use.begin();
+ }
+ // Number of distinct quantized values of a property: one more than its
+ // number of thresholds.
+ size_t NumPropertyValues(size_t property_index) const {
+ return compact_properties[property_index].size() + 1;
+ }
+ // Returns the *quantized* property value.
+ size_t Property(size_t property_index, size_t i) const {
+ return props[property_index][i];
+ }
+ // Returns the threshold with index `quant` of the given property. `quant`
+ // must be a valid threshold index, i.e. not the last quantized value.
+ int UnquantizeProperty(size_t property_index, uint32_t quant) const {
+ JXL_ASSERT(quant < compact_properties[property_index].size());
+ return compact_properties[property_index][quant];
+ }
+
+ // Inverse of PredictorIndex.
+ Predictor PredictorFromIndex(size_t index) const {
+ JXL_DASSERT(index < predictors.size());
+ return predictors[index];
+ }
+ // Inverse of PropertyIndex.
+ size_t PropertyFromIndex(size_t index) const {
+ JXL_DASSERT(index < props_to_use.size());
+ return props_to_use[index];
+ }
+ size_t NumPredictors() const { return predictors.size(); }
+ size_t NumProperties() const { return props_to_use.size(); }
+
+ // Preallocate data for a given number of samples. MUST be called before
+ // adding any sample.
+ void PrepareForSamples(size_t num_samples);
+ // Add a sample.
+ void AddSample(pixel_type_w pixel, const Properties &properties,
+ const pixel_type_w *predictions);
+ // Pre-cluster property values.
+ void PreQuantizeProperties(
+ const StaticPropRange &range,
+ const std::vector<ModularMultiplierInfo> &multiplier_info,
+ const std::vector<uint32_t> &group_pixel_count,
+ const std::vector<uint32_t> &channel_pixel_count,
+ std::vector<pixel_type> &pixel_samples,
+ std::vector<pixel_type> &diff_samples, size_t max_property_values);
+
+ // Releases the deduplication table.
+ void AllSamplesDone() { dedup_table_ = std::vector<uint32_t>(); }
+
+ // Maps a raw property value to its quantized value. `prop` indexes
+ // property_mapping, i.e. it is a position in the list of properties in use.
+ // `v` is clamped to [-kPropertyRange, kPropertyRange] first.
+ uint32_t QuantizeProperty(uint32_t prop, pixel_type v) const {
+ v = std::min(std::max(v, -kPropertyRange), kPropertyRange) + kPropertyRange;
+ return property_mapping[prop][v];
+ }
+
+ // Swaps samples in position a and b. Does nothing if a == b.
+ void Swap(size_t a, size_t b);
+
+ // Cycles samples: a -> b -> c -> a. We assume a <= b <= c, so that we can
+ // just call Swap(a, b) if b==c.
+ void ThreeShuffle(size_t a, size_t b, size_t c);
+
+ private:
+ // TODO(veluca): as the total number of properties and predictors are known
+ // before adding any samples, it might be better to interleave predictors,
+ // properties and counts in a single vector to improve locality.
+ // A first attempt at doing this actually results in much slower encoding,
+ // possibly because of the more complex addressing.
+ struct ResidualToken {
+ uint8_t tok;
+ uint8_t nbits;
+ };
+ // Residual information: token and number of extra bits, per predictor.
+ std::vector<std::vector<ResidualToken>> residuals;
+ // Number of occurrences of each sample.
+ std::vector<uint16_t> sample_counts;
+ // Property values, quantized to at most 256 distinct values.
+ std::vector<std::vector<uint8_t>> props;
+ // Decompactification info for `props`.
+ std::vector<std::vector<int32_t>> compact_properties;
+ // List of properties to use.
+ std::vector<uint32_t> props_to_use;
+ // List of predictors to use.
+ std::vector<Predictor> predictors;
+ // Mapping property value -> quantized property value.
+ static constexpr int32_t kPropertyRange = 511;
+ std::vector<std::vector<uint8_t>> property_mapping;
+ // Number of samples seen.
+ size_t num_samples = 0;
+ // Table for deduplication.
+ static constexpr uint32_t kDedupEntryUnused{static_cast<uint32_t>(-1)};
+ std::vector<uint32_t> dedup_table_;
+
+ // Functions for sample deduplication.
+ bool IsSameSample(size_t a, size_t b) const;
+ size_t Hash1(size_t a) const;
+ size_t Hash2(size_t a) const;
+ // Resizes the table to `size` slots; `size` must be a power of two.
+ void InitTable(size_t size);
+ // Returns true if `a` was already present in the table.
+ bool AddToTableAndMerge(size_t a);
+ // Stores `a` in one of its two candidate slots if either is free.
+ void AddToTable(size_t a);
+};
+
+// Appends the token representation of `tree` to `tokens` and rebuilds the
+// tree, in breadth-first order, into `decoder_tree` (cleared first).
+void TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
+ Tree *decoder_tree);
+
+// Appends randomly selected pixel values and horizontal neighbour differences
+// from `image` to `pixel_samples` / `diff_samples`, and adds the pixel counts
+// of the sampled channels to the per-group and per-channel histograms.
+void CollectPixelSamples(const Image &image, const ModularOptions &options,
+ size_t group_id,
+ std::vector<uint32_t> &group_pixel_count,
+ std::vector<uint32_t> &channel_pixel_count,
+ std::vector<pixel_type> &pixel_samples,
+ std::vector<pixel_type> &diff_samples);
+
+// Appends a root leaf to `tree` and grows the decision tree from it using
+// the samples in `tree_samples`.
+void ComputeBestTree(TreeSamples &tree_samples, float threshold,
+ const std::vector<ModularMultiplierInfo> &mul_info,
+ StaticPropRange static_prop_range,
+ float fast_decode_multiplier, Tree *tree);
+
+} // namespace jxl
+#endif // LIB_JXL_MODULAR_ENCODING_ENC_MA_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc
new file mode 100644
index 0000000000..9d2c3e5cf9
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc
@@ -0,0 +1,622 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/encoding/encoding.h"
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <queue>
+
+#include "lib/jxl/base/printf_macros.h"
+#include "lib/jxl/base/scope_guard.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/options.h"
+
+namespace jxl {
+
+// Removes all nodes that use a static property (i.e. channel or group ID) from
+// the tree and collapses each node on even levels with its two children to
+// produce a flatter tree. Also reports, through the output pointers, how many
+// properties are needed and whether the weighted predictor is used.
+FlatTree FilterTree(const Tree &global_tree,
+ std::array<pixel_type, kNumStaticProperties> &static_props,
+ size_t *num_props, bool *use_wp, bool *wp_only,
+ bool *gradient_only) {
+ *num_props = 0;
+ bool has_wp = false;
+ bool has_non_wp = false;
+ *gradient_only = true;  // Cleared below as soon as a non-gradient property/predictor is seen.
+ const auto mark_property = [&](int32_t p) {  // Records which kinds of properties the output uses.
+ if (p == kWPProp) {
+ has_wp = true;
+ } else if (p >= kNumStaticProperties) {
+ has_non_wp = true;
+ }
+ if (p >= kNumStaticProperties && p != kGradientProp) {
+ *gradient_only = false;
+ }
+ };
+ FlatTree output;
+ std::queue<size_t> nodes;
+ nodes.push(0);
+ // Produces a trimmed and flattened tree by doing a BFS visit of the original
+ // tree, ignoring branches that are known to be false and proceeding two
+ // levels at a time to collapse nodes in a flatter tree; if an inner parent
+ // node has a leaf as a child, the leaf is duplicated and an implicit fake
+ // node is added. This reduces the number of branches taken when traversing
+ // the resulting flat tree.
+ while (!nodes.empty()) {
+ size_t cur = nodes.front();
+ nodes.pop();
+ // Skip nodes that we can decide now, by jumping directly to their children.
+ while (global_tree[cur].property < kNumStaticProperties &&
+ global_tree[cur].property != -1) {
+ if (static_props[global_tree[cur].property] > global_tree[cur].splitval) {
+ cur = global_tree[cur].lchild;  // "value > splitval" selects the left child.
+ } else {
+ cur = global_tree[cur].rchild;
+ }
+ }
+ FlatDecisionNode flat;
+ if (global_tree[cur].property == -1) {  // property == -1 marks a leaf.
+ flat.property0 = -1;
+ flat.childID = global_tree[cur].lchild;  // Later used as an index into context_map by DecodeModularChannelMAANS.
+ flat.predictor = global_tree[cur].predictor;
+ flat.predictor_offset = global_tree[cur].predictor_offset;
+ flat.multiplier = global_tree[cur].multiplier;
+ *gradient_only &= flat.predictor == Predictor::Gradient;
+ has_wp |= flat.predictor == Predictor::Weighted;
+ has_non_wp |= flat.predictor != Predictor::Weighted;
+ output.push_back(flat);
+ continue;
+ }
+ flat.childID = output.size() + nodes.size() + 1;  // Output index of the first grandchild: emitted + queued + this node.
+
+ flat.property0 = global_tree[cur].property;
+ *num_props = std::max<size_t>(flat.property0 + 1, *num_props);
+ flat.splitval0 = global_tree[cur].splitval;
+
+ for (size_t i = 0; i < 2; i++) {  // i == 0: left child, i == 1: right child.
+ size_t cur_child =
+ i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild;
+ // Skip nodes that we can decide now.
+ while (global_tree[cur_child].property < kNumStaticProperties &&
+ global_tree[cur_child].property != -1) {
+ if (static_props[global_tree[cur_child].property] >
+ global_tree[cur_child].splitval) {
+ cur_child = global_tree[cur_child].lchild;
+ } else {
+ cur_child = global_tree[cur_child].rchild;
+ }
+ }
+ // We ended up in a leaf, add a dummy decision and two copies of the leaf.
+ if (global_tree[cur_child].property == -1) {
+ flat.properties[i] = 0;
+ flat.splitvals[i] = 0;
+ nodes.push(cur_child);
+ nodes.push(cur_child);
+ } else {
+ flat.properties[i] = global_tree[cur_child].property;
+ flat.splitvals[i] = global_tree[cur_child].splitval;
+ nodes.push(global_tree[cur_child].lchild);
+ nodes.push(global_tree[cur_child].rchild);
+ *num_props = std::max<size_t>(flat.properties[i] + 1, *num_props);
+ }
+ }
+
+ for (size_t j = 0; j < 2; j++) mark_property(flat.properties[j]);
+ mark_property(flat.property0);
+ output.push_back(flat);
+ }
+ if (*num_props > kNumNonrefProperties) {  // Round the extra properties up to a multiple of kExtraPropsPerChannel.
+ *num_props =
+ DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) *
+ kExtraPropsPerChannel +
+ kNumNonrefProperties;
+ } else {
+ *num_props = kNumNonrefProperties;
+ }
+ *use_wp = has_wp;
+ *wp_only = has_wp && !has_non_wp;
+
+ return output;
+}
+// Decodes channel `chan` of `image`, using the cheapest path that the tree filtered for (chan, group_id) allows.
+Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader,
+ const std::vector<uint8_t> &context_map,
+ const Tree &global_tree,
+ const weighted::Header &wp_header,
+ pixel_type chan, size_t group_id,
+ Image *image) {
+ Channel &channel = image->channel[chan];
+
+ std::array<pixel_type, kNumStaticProperties> static_props = {
+ {chan, (int)group_id}};
+ // TODO(veluca): filter the tree according to static_props.
+
+ // A channel with zero pixels can happen; there is nothing to decode then.
+ if (channel.w == 0 || channel.h == 0) return true;
+
+ bool tree_has_wp_prop_or_pred = false;
+ bool is_wp_only = false;
+ bool is_gradient_only = false;
+ size_t num_props;
+ FlatTree tree =
+ FilterTree(global_tree, static_props, &num_props,
+ &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only);
+
+ // From here on, tree lookup returns a *clustered* context ID.
+ // This avoids an extra memory lookup after tree traversal.
+ for (size_t i = 0; i < tree.size(); i++) {
+ if (tree[i].property0 == -1) {  // Only leaves carry a context ID in childID.
+ tree[i].childID = context_map[tree[i].childID];
+ }
+ }
+
+ JXL_DEBUG_V(3, "Decoded MA tree with %" PRIuS " nodes", tree.size());
+
+ // MAANS decode: pixel = UnpackSigned(token) * multiplier + offset.
+ const auto make_pixel = [](uint64_t v, pixel_type multiplier,
+ pixel_type_w offset) -> pixel_type {
+ JXL_DASSERT((v & 0xFFFFFFFF) == v);
+ pixel_type_w val = UnpackSigned(v);
+ // if it overflows, it overflows, and we have a problem anyway
+ return val * multiplier + offset;
+ };
+
+ if (tree.size() == 1) {
+ // special optimized case: no meta-adaptation, so no need
+ // to compute properties.
+ Predictor predictor = tree[0].predictor;
+ int64_t offset = tree[0].predictor_offset;
+ int32_t multiplier = tree[0].multiplier;
+ size_t ctx_id = tree[0].childID;
+ if (predictor == Predictor::Zero) {
+ uint32_t value;
+ if (reader->IsSingleValueAndAdvance(ctx_id, &value,
+ channel.w * channel.h)) {
+ // Special-case: histogram has a single symbol, with no extra bits, and
+ // we use ANS mode.
+ JXL_DEBUG_V(8, "Fastest track.");
+ pixel_type v = make_pixel(value, multiplier, offset);
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT r = channel.Row(y);
+ std::fill(r, r + channel.w, v);
+ }
+ } else {
+ JXL_DEBUG_V(8, "Fast track.");
+ if (multiplier == 1 && offset == 0) {
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ uint32_t v = reader->ReadHybridUintClustered(ctx_id, br);
+ r[x] = UnpackSigned(v);
+ }
+ }
+ } else {
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ uint32_t v = reader->ReadHybridUintClustered(ctx_id, br);
+ r[x] = make_pixel(v, multiplier, offset);
+ }
+ }
+ }
+ }
+ } else if (predictor == Predictor::Gradient && offset == 0 &&
+ multiplier == 1 && reader->HuffRleOnly()) {
+ JXL_DEBUG_V(8, "Gradient RLE (fjxl) very fast track.");
+ uint32_t run = 0;  // Remaining repetitions of the current residual `sv`.
+ uint32_t v = 0;
+ pixel_type_w sv = 0;
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT r = channel.Row(y);
+ const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1);  // On row 0, rtop[x] aliases r[x - 1] (the left pixel).
+ const pixel_type *JXL_RESTRICT rtopleft =
+ (y ? channel.Row(y - 1) - 1 : r - 1);
+ pixel_type_w guess = (y ? rtop[0] : 0);
+ if (run == 0) {
+ reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run);
+ sv = UnpackSigned(v);
+ } else {
+ run--;
+ }
+ r[0] = sv + guess;
+ for (size_t x = 1; x < channel.w; x++) {
+ pixel_type left = r[x - 1];
+ pixel_type top = rtop[x];
+ pixel_type topleft = rtopleft[x];
+ pixel_type_w guess = ClampedGradient(top, left, topleft);
+ if (!run) {
+ reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run);
+ sv = UnpackSigned(v);
+ } else {
+ run--;
+ }
+ r[x] = sv + guess;
+ }
+ }
+ } else if (predictor == Predictor::Gradient && offset == 0 &&
+ multiplier == 1) {
+ JXL_DEBUG_V(8, "Gradient very fast track.");
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ pixel_type left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);  // Falls back to the pixel above, then to 0.
+ pixel_type top = (y ? *(r + x - onerow) : left);
+ pixel_type topleft = (x && y ? *(r + x - 1 - onerow) : left);
+ pixel_type guess = ClampedGradient(top, left, topleft);
+ uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
+ r[x] = make_pixel(v, 1, guess);
+ }
+ }
+ } else if (predictor != Predictor::Weighted) {
+ // special optimized case: no wp
+ JXL_DEBUG_V(8, "Quite fast track.");
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ PredictionResult pred =
+ PredictNoTreeNoWP(channel.w, r + x, onerow, x, y, predictor);
+ pixel_type_w g = pred.guess + offset;
+ uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
+ // NOTE: pred.multiplier is unset.
+ r[x] = make_pixel(v, multiplier, g);
+ }
+ }
+ } else {
+ JXL_DEBUG_V(8, "Somewhat fast track.");
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ weighted::State wp_state(wp_header, channel.w, channel.h);
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ pixel_type_w g = PredictNoTreeWP(channel.w, r + x, onerow, x, y,
+ predictor, &wp_state)
+ .guess +
+ offset;
+ uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
+ r[x] = make_pixel(v, multiplier, g);
+ wp_state.UpdateErrors(r[x], x, y, channel.w);
+ }
+ }
+ }
+ return true;
+ }
+
+ // Check if this tree is a WP-only tree with a small enough property value
+ // range.
+ // Initialized to avoid clang-tidy complaining.
+ uint8_t context_lookup[2 * kPropRangeFast] = {};
+ int8_t multipliers[2 * kPropRangeFast] = {};
+ int8_t offsets[2 * kPropRangeFast] = {};
+ if (is_wp_only) {  // Keep the flag only if the tree can be turned into a lookup table.
+ is_wp_only = TreeToLookupTable(tree, context_lookup, offsets, multipliers);
+ }
+ if (is_gradient_only) {
+ is_gradient_only =
+ TreeToLookupTable(tree, context_lookup, offsets, multipliers);
+ }
+
+ if (is_gradient_only) {
+ JXL_DEBUG_V(8, "Gradient fast track.");
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
+ pixel_type_w top = (y ? *(r + x - onerow) : left);
+ pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
+ int32_t guess = ClampedGradient(top, left, topleft);
+ uint32_t pos =  // top + left - topleft, clamped to the table range and shifted to a non-negative index.
+ kPropRangeFast +
+ std::min<pixel_type_w>(
+ std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
+ kPropRangeFast - 1);
+ uint32_t ctx_id = context_lookup[pos];
+ uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
+ r[x] = make_pixel(v, multipliers[pos],
+ static_cast<pixel_type_w>(offsets[pos]) + guess);
+ }
+ }
+ } else if (is_wp_only) {
+ JXL_DEBUG_V(8, "WP fast track.");
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ weighted::State wp_state(wp_header, channel.w, channel.h);
+ Properties properties(1);  // Single slot; properties[0] is read back after Predict below.
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT r = channel.Row(y);
+ for (size_t x = 0; x < channel.w; x++) {
+ size_t offset = 0;
+ pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
+ pixel_type_w top = (y ? *(r + x - onerow) : left);
+ pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
+ pixel_type_w topright =
+ (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top);
+ pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
+ int32_t guess = wp_state.Predict</*compute_properties=*/true>(
+ x, y, channel.w, top, left, topright, topleft, toptop, &properties,
+ offset);
+ uint32_t pos =
+ kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
+ kPropRangeFast - 1);
+ uint32_t ctx_id = context_lookup[pos];
+ uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
+ r[x] = make_pixel(v, multipliers[pos],
+ static_cast<pixel_type_w>(offsets[pos]) + guess);
+ wp_state.UpdateErrors(r[x], x, y, channel.w);
+ }
+ }
+ } else if (!tree_has_wp_prop_or_pred) {
+ // special optimized case: the weighted predictor and its properties are not
+ // used, so no need to compute weights and properties.
+ JXL_DEBUG_V(8, "Slow track.");
+ MATreeLookup tree_lookup(tree);
+ Properties properties = Properties(num_props);
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ Channel references(properties.size() - kNumNonrefProperties, channel.w);
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT p = channel.Row(y);
+ PrecomputeReferences(channel, y, *image, chan, &references);
+ InitPropsRow(&properties, static_props, y);
+ if (y > 1 && channel.w > 8 && references.w == 0) {  // Interior columns [2, w - 2) use the NEC variant (presumably "no edge cases" - confirm).
+ for (size_t x = 0; x < 2; x++) {
+ PredictionResult res =
+ PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
+ tree_lookup, references);
+ uint64_t v = reader->ReadHybridUintClustered(res.context, br);
+ p[x] = make_pixel(v, res.multiplier, res.guess);
+ }
+ for (size_t x = 2; x < channel.w - 2; x++) {
+ PredictionResult res =
+ PredictTreeNoWPNEC(&properties, channel.w, p + x, onerow, x, y,
+ tree_lookup, references);
+ uint64_t v = reader->ReadHybridUintClustered(res.context, br);
+ p[x] = make_pixel(v, res.multiplier, res.guess);
+ }
+ for (size_t x = channel.w - 2; x < channel.w; x++) {
+ PredictionResult res =
+ PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
+ tree_lookup, references);
+ uint64_t v = reader->ReadHybridUintClustered(res.context, br);
+ p[x] = make_pixel(v, res.multiplier, res.guess);
+ }
+ } else {
+ for (size_t x = 0; x < channel.w; x++) {
+ PredictionResult res =
+ PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
+ tree_lookup, references);
+ uint64_t v = reader->ReadHybridUintClustered(res.context, br);
+ p[x] = make_pixel(v, res.multiplier, res.guess);
+ }
+ }
+ }
+ } else {
+ JXL_DEBUG_V(8, "Slowest track.");
+ MATreeLookup tree_lookup(tree);
+ Properties properties = Properties(num_props);
+ const intptr_t onerow = channel.plane.PixelsPerRow();
+ Channel references(properties.size() - kNumNonrefProperties, channel.w);
+ weighted::State wp_state(wp_header, channel.w, channel.h);
+ for (size_t y = 0; y < channel.h; y++) {
+ pixel_type *JXL_RESTRICT p = channel.Row(y);
+ InitPropsRow(&properties, static_props, y);
+ PrecomputeReferences(channel, y, *image, chan, &references);
+ for (size_t x = 0; x < channel.w; x++) {
+ PredictionResult res =
+ PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
+ tree_lookup, references, &wp_state);
+ uint64_t v = reader->ReadHybridUintClustered(res.context, br);
+ p[x] = make_pixel(v, res.multiplier, res.guess);
+ wp_state.UpdateErrors(p[x], x, y, channel.w);
+ }
+ }
+ }
+ return true;
+}
+
+// Constructs the header by handing `this` to Bundle::Init.
+GroupHeader::GroupHeader() {
+  Bundle::Init(this);
+}
+
+// Starting at the first non-meta channel with a dimension larger than
+// options.group_dim, checks that every non-empty channel still has a non-zero
+// tile size after the (DC or non-DC) group dimension is shifted right by the
+// larger of the channel's two shifts. Fails with "Inconsistent transforms"
+// otherwise.
+Status ValidateChannelDimensions(const Image &image,
+                                 const ModularOptions &options) {
+  const size_t num_channels = image.channel.size();
+  // The starting channel does not depend on which pass is running, so it is
+  // located once.
+  size_t first_big = image.nb_meta_channels;
+  while (first_big < num_channels &&
+         image.channel[first_big].w <= options.group_dim &&
+         image.channel[first_big].h <= options.group_dim) {
+    first_big++;
+  }
+  const bool passes[2] = {true, false};
+  for (size_t pass = 0; pass < 2; pass++) {
+    const bool dc_pass = passes[pass];
+    const size_t dim = options.group_dim * (dc_pass ? kBlockDim : 1);
+    for (size_t idx = first_big; idx < num_channels; idx++) {
+      const Channel &chan = image.channel[idx];
+      // Empty channels are never checked.
+      if (chan.w == 0 || chan.h == 0) continue;
+      // A channel is handled in the DC pass iff both shifts are at least 3.
+      const bool dc_channel = std::min(chan.hshift, chan.vshift) >= 3;
+      if (dc_channel != dc_pass) continue;
+      const size_t tile = dim >> std::max(chan.hshift, chan.vshift);
+      if (tile == 0) {
+        return JXL_FAILURE("Inconsistent transforms");
+      }
+    }
+  }
+  return true;
+}
+// Reads the group header, applies the meta-transforms, obtains the MA tree (local or global) and decodes the channels in order.
+Status ModularDecode(BitReader *br, Image &image, GroupHeader &header,
+ size_t group_id, ModularOptions *options,
+ const Tree *global_tree, const ANSCode *global_code,
+ const std::vector<uint8_t> *global_ctx_map,
+ bool allow_truncated_group) {
+ if (image.channel.empty()) return true;
+
+ // decode transforms
+ Status status = Bundle::Read(br, &header);
+ if (!allow_truncated_group) JXL_RETURN_IF_ERROR(status);
+ if (status.IsFatalError()) return status;  // With truncation allowed, only fatal errors abort here.
+ if (!br->AllReadsWithinBounds()) {
+ // Don't do/undo transforms if header is incomplete.
+ header.transforms.clear();
+ image.transform = header.transforms;
+ for (size_t c = 0; c < image.channel.size(); c++) {
+ ZeroFillImage(&image.channel[c].plane);
+ }
+ return Status(StatusCode::kNotEnoughBytes);
+ }
+
+ JXL_DEBUG_V(3, "Image data underwent %" PRIuS " transformations: ",
+ header.transforms.size());
+ image.transform = header.transforms;
+ for (Transform &transform : image.transform) {
+ JXL_RETURN_IF_ERROR(transform.MetaApply(image));
+ }
+ if (image.error) {
+ return JXL_FAILURE("Corrupt file. Aborting.");
+ }
+ JXL_RETURN_IF_ERROR(ValidateChannelDimensions(image, *options));
+
+ size_t nb_channels = image.channel.size();
+
+ size_t num_chans = 0;  // Number of channels that will actually be decoded.
+ size_t distance_multiplier = 0;  // Largest width among the channels to decode.
+ for (size_t i = 0; i < nb_channels; i++) {
+ Channel &channel = image.channel[i];
+ if (!channel.w || !channel.h) {
+ continue; // skip empty channels
+ }
+ if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
+ channel.h > options->max_chan_size)) {
+ break;
+ }
+ if (channel.w > distance_multiplier) {
+ distance_multiplier = channel.w;
+ }
+ num_chans++;
+ }
+ if (num_chans == 0) return true;
+
+ size_t next_channel = 0;
+ auto scope_guard = MakeScopeGuard([&]() {  // Zero-fills channels from next_channel onwards unless disarmed below.
+ // Do not do anything if truncated groups are not allowed.
+ if (!allow_truncated_group) return;
+ for (size_t c = next_channel; c < nb_channels; c++) {
+ ZeroFillImage(&image.channel[c].plane);
+ }
+ });
+
+ // Read tree.
+ Tree tree_storage;
+ std::vector<uint8_t> context_map_storage;
+ ANSCode code_storage;
+ const Tree *tree = &tree_storage;
+ const ANSCode *code = &code_storage;
+ const std::vector<uint8_t> *context_map = &context_map_storage;
+ if (!header.use_global_tree) {
+ size_t max_tree_size = 1024;  // Limit: 1024 + pixels to decode, capped at 1 << 20 below.
+ for (size_t i = 0; i < nb_channels; i++) {
+ Channel &channel = image.channel[i];
+ if (!channel.w || !channel.h) {
+ continue; // skip empty channels
+ }
+ if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
+ channel.h > options->max_chan_size)) {
+ break;
+ }
+ size_t pixels = channel.w * channel.h;
+ if (pixels / channel.w != channel.h) {  // The multiplication above wrapped around.
+ return JXL_FAILURE("Tree size overflow");
+ }
+ max_tree_size += pixels;
+ if (max_tree_size < pixels) return JXL_FAILURE("Tree size overflow");
+ }
+ max_tree_size = std::min(static_cast<size_t>(1 << 20), max_tree_size);
+ JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, max_tree_size));
+ JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2,  // (size + 1) / 2 is the leaf count when every inner node has two children.
+ &code_storage, &context_map_storage));
+ } else {
+ if (!global_tree || !global_code || !global_ctx_map ||
+ global_tree->empty()) {
+ return JXL_FAILURE("No global tree available but one was requested");
+ }
+ tree = global_tree;
+ code = global_code;
+ context_map = global_ctx_map;
+ }
+
+ // Read channels
+ ANSSymbolReader reader(code, br, distance_multiplier);
+ for (; next_channel < nb_channels; next_channel++) {
+ Channel &channel = image.channel[next_channel];
+ if (!channel.w || !channel.h) {
+ continue; // skip empty channels
+ }
+ if (next_channel >= image.nb_meta_channels &&
+ (channel.w > options->max_chan_size ||
+ channel.h > options->max_chan_size)) {
+ break;
+ }
+ JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS(
+ br, &reader, *context_map, *tree, header.wp_header, next_channel,
+ group_id, &image));
+ // Truncated group.
+ if (!br->AllReadsWithinBounds()) {
+ if (!allow_truncated_group) return JXL_FAILURE("Truncated input");
+ return Status(StatusCode::kNotEnoughBytes);
+ }
+ }
+
+ // Make sure no zero-filling happens even if next_channel < nb_channels.
+ scope_guard.Disarm();
+
+ if (!reader.CheckANSFinalState()) {
+ return JXL_FAILURE("ANS decode final state failed");
+ }
+ return true;
+}
+// Runs ModularDecode and then, if `undo_transforms` is set, undoes the image's transforms; returns ModularDecode's status.
+Status ModularGenericDecompress(BitReader *br, Image &image,
+ GroupHeader *header, size_t group_id,
+ ModularOptions *options, bool undo_transforms,
+ const Tree *tree, const ANSCode *code,
+ const std::vector<uint8_t> *ctx_map,
+ bool allow_truncated_group) {
+#ifdef JXL_ENABLE_ASSERT
+ std::vector<std::pair<uint32_t, uint32_t>> req_sizes(image.channel.size());  // (w, h) of each channel on entry.
+ for (size_t c = 0; c < req_sizes.size(); c++) {
+ req_sizes[c] = {image.channel[c].w, image.channel[c].h};
+ }
+#endif
+ GroupHeader local_header;
+ if (header == nullptr) header = &local_header;  // Callers may pass nullptr when they do not need the header back.
+ size_t bit_pos = br->TotalBitsConsumed();
+ auto dec_status = ModularDecode(br, image, *header, group_id, options, tree,
+ code, ctx_map, allow_truncated_group);
+ if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
+ if (dec_status.IsFatalError()) return dec_status;  // Non-fatal statuses fall through and are returned at the end.
+ if (undo_transforms) image.undo_transforms(header->wp_header);
+ if (image.error) return JXL_FAILURE("Corrupt file. Aborting.");
+ JXL_DEBUG_V(4,
+ "Modular-decoded a %" PRIuS "x%" PRIuS " nbchans=%" PRIuS
+ " image from %" PRIuS " bytes",
+ image.w, image.h, image.channel.size(),
+ (br->TotalBitsConsumed() - bit_pos) / 8);
+ JXL_DEBUG_V(5, "Modular image: %s", image.DebugString().c_str());
+ (void)bit_pos;  // Presumably silences an unused-variable warning when the debug macros expand to nothing.
+#ifdef JXL_ENABLE_ASSERT
+ // Check that after applying all transforms we are back to the requested image
+ // sizes, otherwise there's a programming error with the transformations.
+ if (undo_transforms) {
+ JXL_ASSERT(image.channel.size() == req_sizes.size());
+ for (size_t c = 0; c < req_sizes.size(); c++) {
+ JXL_ASSERT(req_sizes[c].first == image.channel[c].w);
+ JXL_ASSERT(req_sizes[c].second == image.channel[c].h);
+ }
+ }
+#endif
+ return dec_status;
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h
new file mode 100644
index 0000000000..89697bce87
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h
@@ -0,0 +1,135 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_ENCODING_ENCODING_H_
+#define LIB_JXL_MODULAR_ENCODING_ENCODING_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/encoding/dec_ma.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/options.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+namespace jxl {
+
+// Valid range of properties for using lookup tables instead of trees.
+constexpr int32_t kPropRangeFast = 512;
+
+struct GroupHeader : public Fields {  // Header visited field by field: global-tree flag, WP header, transform list.
+ GroupHeader();
+
+ JXL_FIELDS_NAME(GroupHeader)
+
+ Status VisitFields(Visitor *JXL_RESTRICT visitor) override {  // Fields are visited in the fixed order below.
+ JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &use_global_tree));
+ JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&wp_header));
+ uint32_t num_transforms = static_cast<uint32_t>(transforms.size());  // Starts from the current vector size.
+ JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(4, 2),
+ BitsOffset(8, 18), 0,
+ &num_transforms));
+ if (visitor->IsReading()) transforms.resize(num_transforms);  // When reading, size the vector to the visited count.
+ for (size_t i = 0; i < num_transforms; i++) {
+ JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&transforms[i]));
+ }
+ return true;
+ }
+
+ bool use_global_tree;
+ weighted::Header wp_header;
+
+ std::vector<Transform> transforms;
+};
+
+FlatTree FilterTree(const Tree &global_tree,
+ std::array<pixel_type, kNumStaticProperties> &static_props,
+ size_t *num_props, bool *use_wp, bool *wp_only,
+ bool *gradient_only);
+// Fills per-value lookup arrays (index = value + kPropRangeFast) from a flat tree; returns false if the tree cannot be represented that way.
+template <typename T>
+bool TreeToLookupTable(const FlatTree &tree,
+ T context_lookup[2 * kPropRangeFast],
+ int8_t offsets[2 * kPropRangeFast],
+ int8_t multipliers[2 * kPropRangeFast] = nullptr) {  // With nullptr, only multiplier == 1 leaves are accepted.
+ struct TreeRange {
+ // Begin *excluded*, end *included*. This works best with > vs <= decision
+ // nodes.
+ int begin, end;
+ size_t pos;  // Index of the node in `tree` that covers this range.
+ };
+ std::vector<TreeRange> ranges;  // Used as a stack of (value range, node) pairs still to process.
+ ranges.push_back(TreeRange{-kPropRangeFast - 1, kPropRangeFast - 1, 0});
+ while (!ranges.empty()) {
+ TreeRange cur = ranges.back();
+ ranges.pop_back();
+ if (cur.begin < -kPropRangeFast - 1 || cur.begin >= kPropRangeFast - 1 ||
+ cur.end > kPropRangeFast - 1) {
+ // Tree is outside the allowed range, exit.
+ return false;
+ }
+ auto &node = tree[cur.pos];
+ // Leaf: offset and multiplier must fit in int8_t to be stored in the tables.
+ if (node.property0 == -1) {
+ if (node.predictor_offset < std::numeric_limits<int8_t>::min() ||
+ node.predictor_offset > std::numeric_limits<int8_t>::max()) {
+ return false;
+ }
+ if (node.multiplier < std::numeric_limits<int8_t>::min() ||
+ node.multiplier > std::numeric_limits<int8_t>::max()) {
+ return false;
+ }
+ if (multipliers == nullptr && node.multiplier != 1) {
+ return false;
+ }
+ for (int i = cur.begin + 1; i < cur.end + 1; i++) {  // Covers (begin, end], shifted to non-negative indices.
+ context_lookup[i + kPropRangeFast] = node.childID;
+ if (multipliers) multipliers[i + kPropRangeFast] = node.multiplier;
+ offsets[i + kPropRangeFast] = node.predictor_offset;
+ }
+ continue;
+ }
+ // > side of top node.
+ if (node.properties[0] >= kNumStaticProperties) {  // Otherwise no second-level split (e.g. the dummy 0 FilterTree stores for leaf children).
+ ranges.push_back(TreeRange({node.splitvals[0], cur.end, node.childID}));
+ ranges.push_back(
+ TreeRange({node.splitval0, node.splitvals[0], node.childID + 1}));
+ } else {
+ ranges.push_back(TreeRange({node.splitval0, cur.end, node.childID}));
+ }
+ // <= side
+ if (node.properties[1] >= kNumStaticProperties) {
+ ranges.push_back(
+ TreeRange({node.splitvals[1], node.splitval0, node.childID + 2}));
+ ranges.push_back(
+ TreeRange({cur.begin, node.splitvals[1], node.childID + 3}));
+ } else {
+ ranges.push_back(
+ TreeRange({cur.begin, node.splitval0, node.childID + 2}));
+ }
+ }
+ return true;
+}
+// TODO(veluca): make cleaner interfaces.
+
+Status ValidateChannelDimensions(const Image &image,
+ const ModularOptions &options);
+
+Status ModularGenericDecompress(BitReader *br, Image &image,
+ GroupHeader *header, size_t group_id,
+ ModularOptions *options,
+ bool undo_transforms = true,
+ const Tree *tree = nullptr,
+ const ANSCode *code = nullptr,
+ const std::vector<uint8_t> *ctx_map = nullptr,
+ bool allow_truncated_group = false);
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_ENCODING_ENCODING_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h
new file mode 100644
index 0000000000..71b7847321
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h
@@ -0,0 +1,28 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_
+#define LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_
+
+#include <stddef.h>
+
+namespace jxl {
+
+enum MATreeContext : size_t {  // Named context indices 0..5, followed by their count.
+ kSplitValContext = 0,
+ kPropertyContext = 1,
+ kPredictorContext = 2,
+ kOffsetContext = 3,
+ kMultiplierLogContext = 4,
+ kMultiplierBitsContext = 5,
+
+ kNumTreeContexts = 6,  // Number of context values listed above.
+};
+
+static constexpr size_t kMaxTreeSize = 1 << 22;
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc b/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc
new file mode 100644
index 0000000000..785d0c5443
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc
@@ -0,0 +1,77 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/modular_image.h"
+
+#include <sstream>
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+namespace jxl {
+
+// Inverts the recorded transforms from the most recent one backwards, dropping
+// each transform once it has been undone. On the first failure `error` is set
+// and the function returns, leaving the failing transform (and any earlier
+// ones) in the list.
+void Image::undo_transforms(const weighted::Header &wp_header,
+                            jxl::ThreadPool *pool) {
+  for (; !transform.empty(); transform.pop_back()) {
+    // Work on a copy of the last transform, as Inverse receives the image.
+    Transform last = transform.back();
+    JXL_DEBUG_V(4, "Undoing transform");
+    Status inverse_status = last.Inverse(*this, wp_header, pool);
+    if (!inverse_status) {
+      JXL_NOTIFY_ERROR("Error while undoing transform.");
+      error = true;
+      return;
+    }
+    JXL_DEBUG_V(8, "Undoing transform: done");
+  }
+}
+
+Image::Image(size_t iw, size_t ih, int bitdepth, int nb_chans)
+ : w(iw), h(ih), bitdepth(bitdepth), nb_meta_channels(0), error(false) {
+ for (int i = 0; i < nb_chans; i++) channel.emplace_back(Channel(iw, ih));
+}
+
+// A default-constructed image has no channels, zero size, bit depth 8, and is
+// flagged as erroneous.
+Image::Image()
+    : w(0), h(0), bitdepth(8), nb_meta_channels(0), error(true) {}
+
+// Move assignment: takes over the channel and transform vectors of `other` and
+// copies its scalar fields.
+Image &Image::operator=(Image &&other) noexcept {
+  // Containers are moved...
+  channel = std::move(other.channel);
+  transform = std::move(other.transform);
+  // ...plain fields are copied.
+  error = other.error;
+  nb_meta_channels = other.nb_meta_channels;
+  bitdepth = other.bitdepth;
+  h = other.h;
+  w = other.w;
+  return *this;
+}
+
+Image Image::clone() {
+ Image c(w, h, bitdepth, 0);
+ c.nb_meta_channels = nb_meta_channels;
+ c.error = error;
+ c.transform = transform;
+ for (Channel &ch : channel) {
+ Channel a(ch.w, ch.h, ch.hshift, ch.vshift);
+ CopyImageTo(ch.plane, &a.plane);
+ c.channel.push_back(std::move(a));
+ }
+ return c;
+}
+
+// Formats "<w>x<h>, depth: <bitdepth>", followed (when there are channels) by
+// ", channels:" and one " <w>x<h>(shift: <hshift>,<vshift>)" entry per
+// channel; entries of meta channels get a trailing "*".
+std::string Image::DebugString() const {
+  std::ostringstream out;
+  out << w << "x" << h << ", depth: " << bitdepth;
+  if (channel.empty()) return out.str();
+  out << ", channels:";
+  size_t idx = 0;
+  for (const Channel &ch : channel) {
+    out << " " << ch.w << "x" << ch.h << "(shift: " << ch.hshift << ","
+        << ch.vshift << ")";
+    if (idx < nb_meta_channels) out << "*";
+    idx++;
+  }
+  return out.str();
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/modular_image.h b/third_party/jpeg-xl/lib/jxl/modular/modular_image.h
new file mode 100644
index 0000000000..3e9b5a8a08
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/modular_image.h
@@ -0,0 +1,118 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_MODULAR_IMAGE_H_
+#define LIB_JXL_MODULAR_MODULAR_IMAGE_H_
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+
+typedef int32_t pixel_type; // can use int16_t if it's only for 8-bit images.
+ // Need some wiggle room for YCoCg / Squeeze etc
+
+typedef int64_t pixel_type_w;
+
+namespace weighted {
+struct Header;
+}
+
+// One plane of pixel_type samples together with its logical size and the
+// shifts relating it to the image size. Move-only.
+class Channel {
+ public:
+  jxl::Plane<pixel_type> plane;
+  size_t w;
+  size_t h;
+  // w ~= image.w >> hshift; h ~= image.h >> vshift
+  int hshift;
+  int vshift;
+
+  Channel(size_t iw, size_t ih, int hsh = 0, int vsh = 0)
+      : plane(iw, ih), w(iw), h(ih), hshift(hsh), vshift(vsh) {}
+
+  // Copying is disabled.
+  Channel(const Channel& other) = delete;
+  Channel& operator=(const Channel& other) = delete;
+
+  // Moving is allowed.
+  Channel(Channel&& other) noexcept = default;
+  Channel& operator=(Channel&& other) noexcept {
+    plane = std::move(other.plane);
+    vshift = other.vshift;
+    hshift = other.hshift;
+    h = other.h;
+    w = other.w;
+    return *this;
+  }
+
+  // Replaces the plane with a fresh w x h one unless it already has exactly
+  // that size.
+  void shrink() {
+    const bool size_matches = plane.xsize() == w && plane.ysize() == h;
+    if (!size_matches) {
+      plane = jxl::Plane<pixel_type>(w, h);
+    }
+  }
+  // Sets the logical size to nw x nh and then calls shrink().
+  void shrink(int nw, int nh) {
+    w = nw;
+    h = nh;
+    shrink();
+  }
+
+  // Row accessors forwarding to the plane.
+  JXL_INLINE pixel_type* Row(const size_t y) { return plane.Row(y); }
+  JXL_INLINE const pixel_type* Row(const size_t y) const {
+    return plane.Row(y);
+  }
+};
+
+class Transform;
+
+// A modular image: a list of channels plus the transforms recorded on them.
+// Move-only.
+class Image {
+ public:
+  // Pixel data. Transforms can dramatically change how many channels there are
+  // and what they mean.
+  std::vector<Channel> channel;
+  // Transforms that have been applied and still have to be undone.
+  std::vector<Transform> transform;
+
+  // Image dimensions; channels may have different dimensions due to
+  // transforms.
+  size_t w;
+  size_t h;
+  int bitdepth;
+  // The first few channels might contain palette(s).
+  size_t nb_meta_channels;
+  // True if a fatal error occurred, false otherwise.
+  bool error;
+
+  Image(size_t iw, size_t ih, int bitdepth, int nb_chans);
+  Image();
+
+  // Copying is disabled; use clone() for an explicit copy.
+  Image(const Image& other) = delete;
+  Image& operator=(const Image& other) = delete;
+
+  Image& operator=(Image&& other) noexcept;
+  Image(Image&& other) noexcept = default;
+
+  // True when no channel has both a non-zero width and a non-zero height.
+  bool empty() const {
+    for (size_t i = 0; i < channel.size(); i++) {
+      if (channel[i].w != 0 && channel[i].h != 0) return false;
+    }
+    return true;
+  }
+
+  Image clone();
+
+  void undo_transforms(const weighted::Header& wp_header,
+                       jxl::ThreadPool* pool = nullptr);
+
+  std::string DebugString() const;
+};
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_MODULAR_IMAGE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/options.h b/third_party/jpeg-xl/lib/jxl/modular/options.h
new file mode 100644
index 0000000000..ce6596b912
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/options.h
@@ -0,0 +1,117 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_OPTIONS_H_
+#define LIB_JXL_MODULAR_OPTIONS_H_
+
+#include <stdint.h>
+
+#include <array>
+#include <vector>
+
+namespace jxl {
+
+// Value of a single property, and a list of such values.
+using PropertyVal = int32_t;
+using Properties = std::vector<PropertyVal>;
+
+// Pixel predictors, each with a fixed numeric id. Ids 0..13 (Zero..Average4)
+// are the ones counted by kNumModularPredictors; Best and Variable are
+// encoder-only choices and are counted only by kNumModularEncoderPredictors.
+enum class Predictor : uint32_t {
+ Zero = 0,
+ Left = 1,
+ Top = 2,
+ Average0 = 3,
+ Select = 4,
+ Gradient = 5,
+ Weighted = 6,
+ TopRight = 7,
+ TopLeft = 8,
+ LeftLeft = 9,
+ Average1 = 10,
+ Average2 = 11,
+ Average3 = 12,
+ Average4 = 13,
+ // The following predictors are encoder-only.
+ Best = 14, // Best of Gradient and Weighted
+ Variable =
+ 15, // Find the best decision tree for predictors/predictor per row
+};
+
+// Number of predictor ids up to and including Average4 (14).
+constexpr size_t kNumModularPredictors =
+ static_cast<size_t>(Predictor::Average4) + 1;
+// Number of predictor ids including the encoder-only ones (16).
+constexpr size_t kNumModularEncoderPredictors =
+ static_cast<size_t>(Predictor::Variable) + 1;
+
+// NOTE(review): declared as ssize_t (a POSIX type, not declared by the
+// headers included in this file) although it is used as an array extent
+// below - confirm where ssize_t comes from on non-POSIX toolchains.
+static constexpr ssize_t kNumStaticProperties = 2; // channel, group_id.
+
+// One pair of uint32_t per static property.
+// NOTE(review): presumably the two entries are the bounds of a value range -
+// confirm whether the upper bound is inclusive.
+using StaticPropRange =
+ std::array<std::array<uint32_t, 2>, kNumStaticProperties>;
+
+// Associates a multiplier with a range of static property values.
+// NOTE(review): how `multiplier` is applied is not visible here - confirm at
+// the use sites.
+struct ModularMultiplierInfo {
+ StaticPropRange range;
+ uint32_t multiplier;
+};
+
+// Options for the modular encoder/decoder. The first group is used by both
+// sides; everything after "Encode options" is read by the encoder only.
+struct ModularOptions {
+ /// Used in both encode and decode:
+
+ // Stop encoding/decoding when reaching a (non-meta) channel that has a
+ // dimension bigger than max_chan_size.
+ size_t max_chan_size = 0xFFFFFF;
+
+ // Used during decoding for validation of the transform (squeezing) scheme.
+ size_t group_dim = 0x1FFFFFFF;
+
+ /// Encode options:
+ // Fraction of pixels to look at to learn a MA tree
+ // (if zero there is no MA context model).
+ // NOTE(review): this comment used to also say "Number of iterations to do
+ // to learn a MA tree"; the float type and the .5f default match the
+ // "fraction" reading - confirm against the tree learning code.
+ float nb_repeats = .5f;
+
+ // Maximum number of (previous channel) properties to use in the MA trees
+ int max_properties = 0; // no previous channels
+
+ // Alternative heuristic tweaks.
+ // Properties default to channel, group, weighted, gradient residual, W-NW,
+ // NW-N, N-NE, N-NN
+ std::vector<uint32_t> splitting_heuristics_properties = {0, 1, 15, 9,
+ 10, 11, 12, 13};
+ float splitting_heuristics_node_threshold = 96;
+ size_t max_property_values = 32;
+
+ // Predictor to use for each channel.
+ // The default is -1 cast to Predictor, i.e. a value that is not one of the
+ // enumerators. NOTE(review): presumably means "not chosen yet" - confirm.
+ Predictor predictor = static_cast<Predictor>(-1);
+
+ int wp_mode = 0;
+
+ float fast_decode_multiplier = 1.01f;
+
+ // Forces the encoder to produce a tree that is compatible with the WP-only
+ // decode path (or with the no-wp path, or the gradient-only path).
+ enum class TreeMode { kGradientOnly, kWPOnly, kNoWP, kDefault };
+ TreeMode wp_tree_mode = TreeMode::kDefault;
+
+ // Skip fast paths in the encoder.
+ bool skip_encoder_fast_path = false;
+
+ // Kind of tree to use.
+ // TODO(veluca): add tree kinds for JPEG recompression with CfL enabled,
+ // general AC metadata, different DC qualities, and others.
+ enum class TreeKind {
+ kTrivialTreeNoPredictor,
+ kLearn,
+ kJpegTranscodeACMeta,
+ kFalconACMeta,
+ kACMeta,
+ kWPFixedDC,
+ kGradientFixedDC,
+ };
+ TreeKind tree_kind = TreeKind::kLearn;
+
+ // Ignore the image and just pretend all tokens are zeroes
+ bool zero_tokens = false;
+};
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_OPTIONS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc
new file mode 100644
index 0000000000..bc31445bc5
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc
@@ -0,0 +1,606 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/transform/enc_palette.h"
+
+#include <array>
+#include <map>
+#include <set>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/enc_transform.h"
+#include "lib/jxl/modular/transform/palette.h"
+
+namespace jxl {
+
+namespace palette_internal {
+
+// When true, the lossy palette search also tries the index produced by
+// QuantizeColorToImplicitPaletteIndex(..., /*high_quality=*/true).
+static constexpr bool kEncodeToHighQualityImplicitPalette = true;
+
+// Lowest palette index tried by the lossy palette search (-143).
+// Inclusive.
+static constexpr int kMinImplicitPaletteIndex = -(2 * 72 - 1);
+
+// Weighted squared distance between color `a` (floating point) and color `b`
+// (integer samples); both must have the same number of components.
+// Per-component squared differences are weighted (base weights 3, 5, 2 for
+// the first three components and 2 for any further one, with an extra bonus
+// for the first three when the component sum reaches a scaled average), the
+// total is multiplied by 4, and the squared difference of the weighted
+// component sums (weights 3, 5, 1) is added.
+float ColorDistance(const std::vector<float> &JXL_RESTRICT a,
+                    const std::vector<pixel_type> &JXL_RESTRICT b) {
+  JXL_ASSERT(a.size() == b.size());
+  const size_t num_components = a.size();
+
+  // Scaled average of the first three component sums; stays 0 for colors
+  // with fewer than three components.
+  float scaled_mean3 = 0;
+  if (num_components >= 3) {
+    scaled_mean3 = (a[0] + b[0] + a[1] + b[1] + a[2] + b[2]) * (1.21f / 3.0f);
+  }
+
+  float result = 0;
+  float weighted_sum_a = 0;
+  float weighted_sum_b = 0;
+  for (size_t c = 0; c < num_components; ++c) {
+    const float delta = static_cast<float>(a[c]) - static_cast<float>(b[c]);
+    float weight;
+    int sum_weight;
+    switch (c) {
+      case 0:
+        weight = 3;
+        sum_weight = 3;
+        break;
+      case 1:
+        weight = 5;
+        sum_weight = 5;
+        break;
+      default:
+        weight = 2;
+        sum_weight = 1;
+        break;
+    }
+    if (c < 3 && (a[c] + b[c] >= scaled_mean3)) {
+      const float bonus[3] = {
+          1.15,
+          1.15,
+          1.12,
+      };
+      weight += bonus[c];
+      if (c == 2 && ((a[2] + b[2]) < 1.22 * scaled_mean3)) {
+        weight -= 0.5;
+      }
+    }
+    result += delta * delta * weight * weight;
+    weighted_sum_a += a[c] * sum_weight;
+    weighted_sum_b += b[c] * sum_weight;
+  }
+  result *= 4;
+  const float sum_delta = weighted_sum_a - weighted_sum_b;
+  result += sum_delta * sum_delta;
+  return result;
+}
+
+// Maps `color` to an index past the explicit palette by quantizing every
+// component onto a cube. With high_quality each component uses kLargeCube
+// levels and the result is offset by palette_size + kLargeCubeOffset.
+// Otherwise each component is first lowered by 1 << max(0, bit_depth - 3)
+// (clamped at 0), quantized, capped at kSmallCube - 1, and the result is
+// offset by palette_size only.
+static int QuantizeColorToImplicitPaletteIndex(
+    const std::vector<pixel_type> &color, const int palette_size,
+    const int bit_depth, bool high_quality) {
+  int cube_index = 0;
+  int stride = 1;
+  if (!high_quality) {
+    for (size_t c = 0; c < color.size(); c++) {
+      int shifted = color[c];
+      shifted -= 1 << (std::max(0, bit_depth - 3));
+      shifted = std::max(0, shifted);
+      int level = ((kLargeCube - 1) * shifted + (1 << (bit_depth - 1))) /
+                  ((1 << bit_depth) - 1);
+      JXL_ASSERT((level % kLargeCube) == level);
+      if (level > kSmallCube - 1) {
+        level = kSmallCube - 1;
+      }
+      cube_index += level * stride;
+      stride *= kSmallCube;
+    }
+    return cube_index + palette_size;
+  }
+  for (size_t c = 0; c < color.size(); c++) {
+    const int level =
+        ((kLargeCube - 1) * color[c] + (1 << (bit_depth - 1))) /
+        ((1 << bit_depth) - 1);
+    JXL_ASSERT((level % kLargeCube) == level);
+    cube_index += level * stride;
+    stride *= kLargeCube;
+  }
+  return cube_index + palette_size + kLargeCubeOffset;
+}
+
+} // namespace palette_internal
+
+// Divides `value` by `div`, rounding to nearest with symmetric rounding around
+// 0: the magnitude is rounded, then the sign of `value` is put back.
+int RoundInt(int value, int div) {
+  const bool negative = value < 0;
+  const int magnitude = negative ? -value : value;
+  const int rounded = (magnitude + div / 2) / div;
+  return negative ? -rounded : rounded;
+}
+
+// State shared between the passes of FwdPaletteIteration. A non-final pass
+// appends one entry per pixel to `deltas` and `delta_distances`; a later
+// pass turns those into `frequent_deltas` via FindFrequentColorDeltas.
+struct PaletteIterationData {
+ // Upper bound on the number of entries kept in each frequent_deltas[c].
+ static constexpr int kMaxDeltas = 128;
+ bool final_run = false;
+ // Per-pixel residuals, one vector for each of the first three components.
+ std::vector<pixel_type> deltas[3];
+ // Per-pixel distance value recorded together with the residuals above.
+ std::vector<double> delta_distances;
+ std::vector<pixel_type> frequent_deltas[3];
+
+ // Populates `frequent_deltas` with items from `deltas` based on frequencies
+ // and color distances.
+ // `num_pixels` normalizes the accumulated weights; `bitdepth` sets the
+ // bucket size used to cluster deltas (3 << max(0, bitdepth - 8)).
+ void FindFrequentColorDeltas(int num_pixels, int bitdepth) {
+ using pixel_type_3d = std::array<pixel_type, 3>;
+ std::map<pixel_type_3d, double> delta_frequency_map;
+ pixel_type bucket_size = 3 << std::max(0, bitdepth - 8);
+ // Store frequency weighted by delta distance from quantized value.
+ // NOTE(review): the values in delta_distances are the best_distance values
+ // pushed by FwdPaletteIteration, which include an index penalty that can
+ // be negative; sqrt() of a negative value is NaN, and a NaN weight would
+ // also upset the sort comparator below - confirm the inputs cannot be
+ // negative.
+ for (size_t i = 0; i < deltas[0].size(); ++i) {
+ pixel_type_3d delta = {
+ {RoundInt(deltas[0][i], bucket_size),
+ RoundInt(deltas[1][i], bucket_size),
+ RoundInt(deltas[2][i], bucket_size)}}; // a basic form of clustering
+ if (delta[0] == 0 && delta[1] == 0 && delta[2] == 0) continue;
+ delta_frequency_map[delta] += sqrt(sqrt(delta_distances[i]));
+ }
+
+ // Only used inside the loop below, which does not run when the map is
+ // empty.
+ const float delta_distance_multiplier = 1.0f / num_pixels;
+
+ // Weigh frequencies by magnitude and normalize.
+ for (auto &delta_frequency : delta_frequency_map) {
+ std::vector<pixel_type> current_delta = {delta_frequency.first[0],
+ delta_frequency.first[1],
+ delta_frequency.first[2]};
+ float delta_distance =
+ sqrt(palette_internal::ColorDistance({0, 0, 0}, current_delta)) + 1;
+ delta_frequency.second *= delta_distance * delta_distance_multiplier;
+ }
+
+ // Sort by weighted frequency.
+ using pixel_type_3d_frequency = std::pair<pixel_type_3d, double>;
+ std::vector<pixel_type_3d_frequency> sorted_delta_frequency_map(
+ delta_frequency_map.begin(), delta_frequency_map.end());
+ std::sort(
+ sorted_delta_frequency_map.begin(), sorted_delta_frequency_map.end(),
+ [](const pixel_type_3d_frequency &a, const pixel_type_3d_frequency &b) {
+ return a.second > b.second;
+ });
+
+ // Store the top deltas.
+ // Entries are appended to whatever frequent_deltas already holds; the
+ // bucketed delta is scaled back by bucket_size before being stored.
+ for (auto &delta_frequency : sorted_delta_frequency_map) {
+ if (frequent_deltas[0].size() >= kMaxDeltas) break;
+ // Number obtained by optimizing on jyrki31 corpus:
+ if (delta_frequency.second < 17) break;
+ for (int c = 0; c < 3; ++c) {
+ frequent_deltas[c].push_back(delta_frequency.first[c] * bucket_size);
+ }
+ }
+ }
+};
+
+// Runs one pass of the forward palette transform over channels
+// [begin_c, end_c] of `input`.
+//
+// nb_colors: on entry the maximum number of palette colors allowed; on
+//   success the number of colors in the palette (delta entries excluded).
+// nb_deltas: set to the number of delta entries placed in front of the
+//   colors (only the lossy path produces any).
+// ordered: store the palette in lexicographic order instead of image order.
+// lossy: pick the closest palette entry per pixel, with error diffusion,
+//   instead of requiring an exact match.
+// predictor: set to Predictor::Zero when no delta entry ended up being used.
+// palette_iteration_data: a pass with final_run == false only records
+//   per-pixel residuals and distances in it; a pass with final_run == true
+//   writes the indices into channel begin_c, removes the other channels and
+//   inserts the palette as a new first (meta) channel.
+//
+// Returns false when the channels need more than nb_colors colors (or when
+// nb_colors is 0 in the single-channel case).
+//
+// The lossless single-channel case is handled separately up front and
+// modifies `input` regardless of final_run.
+//
+// NOTE(review): when lossy, the delta bookkeeping indexes components 0..2
+// unconditionally (the palette delta rows and the ideal_residual pushes), so
+// it relies on nb >= 3 whenever deltas are recorded or emitted.
+Status FwdPaletteIteration(Image &input, uint32_t begin_c, uint32_t end_c,
+                           uint32_t &nb_colors, uint32_t &nb_deltas,
+                           bool ordered, bool lossy, Predictor &predictor,
+                           const weighted::Header &wp_header,
+                           PaletteIterationData &palette_iteration_data) {
+  JXL_QUIET_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c));
+  JXL_ASSERT(begin_c >= input.nb_meta_channels);
+  uint32_t nb = end_c - begin_c + 1;
+
+  size_t w = input.channel[begin_c].w;
+  size_t h = input.channel[begin_c].h;
+
+  if (!lossy && nb == 1) {
+    // Channel palette special case
+    if (nb_colors == 0) return false;
+    std::vector<pixel_type> lookup;
+    pixel_type minval, maxval;
+    compute_minmax(input.channel[begin_c], &minval, &maxval);
+    size_t lookup_table_size =
+        static_cast<int64_t>(maxval) - static_cast<int64_t>(minval) + 1;
+    if (lookup_table_size > palette_internal::kMaxPaletteLookupTableSize) {
+      // a lookup table would use too much memory, instead use a slower approach
+      // with std::set
+      std::set<pixel_type> chpalette;
+      pixel_type idx = 0;
+      for (size_t y = 0; y < h; y++) {
+        const pixel_type *p = input.channel[begin_c].Row(y);
+        for (size_t x = 0; x < w; x++) {
+          const bool new_color = chpalette.insert(p[x]).second;
+          if (new_color) {
+            idx++;
+            if (idx > (int)nb_colors) return false;
+          }
+        }
+      }
+      JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx);
+      Channel pch(idx, 1);
+      pch.hshift = -1;
+      pch.vshift = -1;
+      nb_colors = idx;
+      idx = 0;
+      pixel_type *JXL_RESTRICT p_palette = pch.Row(0);
+      for (pixel_type p : chpalette) {
+        p_palette[idx++] = p;
+      }
+      for (size_t y = 0; y < h; y++) {
+        pixel_type *p = input.channel[begin_c].Row(y);
+        for (size_t x = 0; x < w; x++) {
+          // Test the bound first so that p_palette[idx] is never read past
+          // the nb_colors entries of the palette row.
+          for (idx = 0;
+               idx < static_cast<int>(nb_colors) && p[x] != p_palette[idx];
+               idx++) {
+          }
+          JXL_DASSERT(idx < (int)nb_colors);
+          p[x] = idx;
+        }
+      }
+      predictor = Predictor::Zero;
+      input.nb_meta_channels++;
+      input.channel.insert(input.channel.begin(), std::move(pch));
+
+      return true;
+    }
+    // First use `lookup` as a "value occurs" flag table, then (below) reuse it
+    // to map each value to its palette index.
+    lookup.resize(lookup_table_size, 0);
+    pixel_type idx = 0;
+    for (size_t y = 0; y < h; y++) {
+      const pixel_type *p = input.channel[begin_c].Row(y);
+      for (size_t x = 0; x < w; x++) {
+        if (lookup[p[x] - minval] == 0) {
+          lookup[p[x] - minval] = 1;
+          idx++;
+          if (idx > (int)nb_colors) return false;
+        }
+      }
+    }
+    JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx);
+    Channel pch(idx, 1);
+    pch.hshift = -1;
+    pch.vshift = -1;
+    nb_colors = idx;
+    idx = 0;
+    pixel_type *JXL_RESTRICT p_palette = pch.Row(0);
+    for (size_t i = 0; i < lookup_table_size; i++) {
+      if (lookup[i]) {
+        p_palette[idx] = i + minval;
+        lookup[i] = idx;
+        idx++;
+      }
+    }
+    for (size_t y = 0; y < h; y++) {
+      pixel_type *p = input.channel[begin_c].Row(y);
+      for (size_t x = 0; x < w; x++) p[x] = lookup[p[x] - minval];
+    }
+    predictor = Predictor::Zero;
+    input.nb_meta_channels++;
+    input.channel.insert(input.channel.begin(), std::move(pch));
+    return true;
+  }
+
+  // Working copy of the channels that receives the quantized values chosen by
+  // the lossy search (it is the input of the predictor used for deltas).
+  Image quantized_input;
+  if (lossy) {
+    quantized_input = Image(w, h, input.bitdepth, nb);
+    for (size_t c = 0; c < nb; c++) {
+      CopyImageTo(input.channel[begin_c + c].plane,
+                  &quantized_input.channel[c].plane);
+    }
+  }
+
+  JXL_DEBUG_V(
+      7, "Trying to represent channels %i-%i using at most a %i-color palette.",
+      begin_c, end_c, nb_colors);
+  nb_deltas = 0;
+  bool delta_used = false;
+  std::set<std::vector<pixel_type>>
+      candidate_palette;  // ordered lexicographically
+  std::vector<std::vector<pixel_type>> candidate_palette_imageorder;
+  std::vector<pixel_type> color(nb);
+  std::vector<float> color_with_error(nb);
+  std::vector<const pixel_type *> p_in(nb);
+
+  if (lossy) {
+    palette_iteration_data.FindFrequentColorDeltas(w * h, input.bitdepth);
+    nb_deltas = palette_iteration_data.frequent_deltas[0].size();
+
+    // Count color frequency for colors that make a cross.
+    std::map<std::vector<pixel_type>, size_t> color_freq_map;
+    for (size_t y = 1; y + 1 < h; y++) {
+      for (uint32_t c = 0; c < nb; c++) {
+        p_in[c] = input.channel[begin_c + c].Row(y);
+      }
+      for (size_t x = 1; x + 1 < w; x++) {
+        for (uint32_t c = 0; c < nb; c++) {
+          color[c] = p_in[c][x];
+        }
+        int offsets[4][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
+        bool makes_cross = true;
+        for (int i = 0; i < 4 && makes_cross; ++i) {
+          int dx = offsets[i][0];
+          int dy = offsets[i][1];
+          for (uint32_t c = 0; c < nb && makes_cross; c++) {
+            if (input.channel[begin_c + c].Row(y + dy)[x + dx] != color[c]) {
+              makes_cross = false;
+            }
+          }
+        }
+        if (makes_cross) color_freq_map[color] += 1;
+      }
+    }
+    // Add colors satisfying frequency condition to the palette.
+    constexpr float kImageFraction = 0.01f;
+    size_t color_frequency_lower_bound = 5 + input.h * input.w * kImageFraction;
+    for (const auto &color_freq : color_freq_map) {
+      if (color_freq.second > color_frequency_lower_bound) {
+        candidate_palette.insert(color_freq.first);
+        candidate_palette_imageorder.push_back(color_freq.first);
+      }
+    }
+  }
+
+  // Collect the distinct colors of the image. The lossy path stops adding
+  // colors once the palette is full; the lossless path gives up instead.
+  for (size_t y = 0; y < h; y++) {
+    for (uint32_t c = 0; c < nb; c++) {
+      p_in[c] = input.channel[begin_c + c].Row(y);
+    }
+    for (size_t x = 0; x < w; x++) {
+      if (lossy && candidate_palette.size() >= nb_colors) break;
+      for (uint32_t c = 0; c < nb; c++) {
+        color[c] = p_in[c][x];
+      }
+      const bool new_color = candidate_palette.insert(color).second;
+      if (new_color) {
+        candidate_palette_imageorder.push_back(color);
+      }
+      if (candidate_palette.size() > nb_colors) {
+        return false;  // too many colors
+      }
+    }
+  }
+
+  // From here on nb_colors counts the delta entries too; they are subtracted
+  // again right before returning.
+  nb_colors = nb_deltas + candidate_palette.size();
+  JXL_DEBUG_V(6, "Channels %i-%i can be represented using a %i-color palette.",
+              begin_c, end_c, nb_colors);
+
+  Channel pch(nb_colors, nb);
+  pch.hshift = -1;
+  pch.vshift = -1;
+  pixel_type *JXL_RESTRICT p_palette = pch.Row(0);
+  intptr_t onerow = pch.plane.PixelsPerRow();
+  intptr_t onerow_image = input.channel[begin_c].plane.PixelsPerRow();
+  const int bit_depth = std::min(input.bitdepth, 24);
+
+  // Palette layout: the first nb_deltas columns hold the frequent deltas, the
+  // colors follow.
+  if (lossy) {
+    for (uint32_t i = 0; i < nb_deltas; i++) {
+      for (size_t c = 0; c < 3; c++) {
+        p_palette[c * onerow + i] =
+            palette_iteration_data.frequent_deltas[c][i];
+      }
+    }
+  }
+
+  int x = 0;
+  if (ordered) {
+    JXL_DEBUG_V(7, "Palette of %i colors, using lexicographic order",
+                nb_colors);
+    for (auto pcol : candidate_palette) {
+      JXL_DEBUG_V(9, "  Color %i :  ", x);
+      for (size_t i = 0; i < nb; i++) {
+        p_palette[nb_deltas + i * onerow + x] = pcol[i];
+      }
+      for (size_t i = 0; i < nb; i++) {
+        JXL_DEBUG_V(9, "%i ", pcol[i]);
+      }
+      x++;
+    }
+  } else {
+    JXL_DEBUG_V(7, "Palette of %i colors, using image order", nb_colors);
+    for (auto pcol : candidate_palette_imageorder) {
+      JXL_DEBUG_V(9, "  Color %i :  ", x);
+      for (size_t i = 0; i < nb; i++)
+        p_palette[nb_deltas + i * onerow + x] = pcol[i];
+      for (size_t i = 0; i < nb; i++) JXL_DEBUG_V(9, "%i ", pcol[i]);
+      x++;
+    }
+  }
+  std::vector<weighted::State> wp_states;
+  for (size_t c = 0; c < nb; c++) {
+    wp_states.emplace_back(wp_header, w, h);
+  }
+  std::vector<pixel_type *> p_quant(nb);
+  // Three rows of error for dithering: y to y + 2.
+  // Each row has two pixels of padding in the ends, which is
+  // beneficial for both precision and encoding speed.
+  std::vector<std::vector<float>> error_row[3];
+  if (lossy) {
+    for (int i = 0; i < 3; ++i) {
+      error_row[i].resize(nb);
+      for (size_t c = 0; c < nb; ++c) {
+        error_row[i][c].resize(w + 4);
+      }
+    }
+  }
+  for (size_t y = 0; y < h; y++) {
+    for (size_t c = 0; c < nb; c++) {
+      p_in[c] = input.channel[begin_c + c].Row(y);
+      if (lossy) p_quant[c] = quantized_input.channel[c].Row(y);
+    }
+    pixel_type *JXL_RESTRICT p = input.channel[begin_c].Row(y);
+    for (size_t x = 0; x < w; x++) {
+      int index;
+      if (!lossy) {
+        for (size_t c = 0; c < nb; c++) color[c] = p_in[c][x];
+        // Exact search.
+        for (index = 0; static_cast<uint32_t>(index) < nb_colors; index++) {
+          bool found = true;
+          for (size_t c = 0; c < nb; c++) {
+            if (color[c] != p_palette[c * onerow + index]) {
+              found = false;
+              break;
+            }
+          }
+          if (found) break;
+        }
+        if (index < static_cast<int>(nb_deltas)) {
+          delta_used = true;
+        }
+      } else {
+        int best_index = 0;
+        bool best_is_delta = false;
+        float best_distance = std::numeric_limits<float>::infinity();
+        std::vector<pixel_type> best_val(nb, 0);
+        std::vector<pixel_type> ideal_residual(nb, 0);
+        std::vector<pixel_type> quantized_val(nb);
+        std::vector<pixel_type> predictions(nb);
+        static const double kDiffusionMultiplier[] = {0.55, 0.75};
+        // Diffused error is only mixed in on the final run (final_run acts as
+        // a 0/1 factor); the search is repeated for both multipliers.
+        for (int diffusion_index = 0; diffusion_index < 2; ++diffusion_index) {
+          for (size_t c = 0; c < nb; c++) {
+            color_with_error[c] =
+                p_in[c][x] + palette_iteration_data.final_run *
+                                 kDiffusionMultiplier[diffusion_index] *
+                                 error_row[0][c][x + 2];
+            color[c] = Clamp1(lroundf(color_with_error[c]), 0l,
+                              (1l << input.bitdepth) - 1);
+          }
+
+          for (size_t c = 0; c < nb; ++c) {
+            predictions[c] = PredictNoTreeWP(w, p_quant[c] + x, onerow_image, x,
+                                             y, predictor, &wp_states[c])
+                                 .guess;
+          }
+          // Scores one candidate index and keeps it if it beats the best so
+          // far. Delta entries are applied on top of the prediction.
+          const auto TryIndex = [&](const int index) {
+            for (size_t c = 0; c < nb; c++) {
+              quantized_val[c] = palette_internal::GetPaletteValue(
+                  p_palette, index, /*c=*/c,
+                  /*palette_size=*/nb_colors,
+                  /*onerow=*/onerow, /*bit_depth=*/bit_depth);
+              if (index < static_cast<int>(nb_deltas)) {
+                quantized_val[c] += predictions[c];
+              }
+            }
+            const float color_distance =
+                32.0 / (1LL << std::max(0, 2 * (bit_depth - 8))) *
+                palette_internal::ColorDistance(color_with_error,
+                                                quantized_val);
+            float index_penalty = 0;
+            if (index == -1) {
+              index_penalty = -124;
+            } else if (index < 0) {
+              index_penalty = -2 * index;
+            } else if (index < static_cast<int>(nb_deltas)) {
+              index_penalty = 250;
+            } else if (index < static_cast<int>(nb_colors)) {
+              index_penalty = 150;
+            } else if (index < static_cast<int>(nb_colors) +
+                                   palette_internal::kLargeCubeOffset) {
+              index_penalty = 70;
+            } else {
+              index_penalty = 256;
+            }
+            const float distance = color_distance + index_penalty;
+            if (distance < best_distance) {
+              best_distance = distance;
+              best_index = index;
+              best_is_delta = index < static_cast<int>(nb_deltas);
+              best_val.swap(quantized_val);
+              for (size_t c = 0; c < nb; ++c) {
+                ideal_residual[c] = color_with_error[c] - predictions[c];
+              }
+            }
+          };
+          for (index = palette_internal::kMinImplicitPaletteIndex;
+               index < static_cast<int32_t>(nb_colors); index++) {
+            TryIndex(index);
+          }
+          TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex(
+              color, nb_colors, bit_depth,
+              /*high_quality=*/false));
+          if (palette_internal::kEncodeToHighQualityImplicitPalette) {
+            TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex(
+                color, nb_colors, bit_depth,
+                /*high_quality=*/true));
+          }
+        }
+        index = best_index;
+        delta_used |= best_is_delta;
+        if (!palette_iteration_data.final_run) {
+          for (size_t c = 0; c < 3; ++c) {
+            palette_iteration_data.deltas[c].push_back(ideal_residual[c]);
+          }
+          palette_iteration_data.delta_distances.push_back(best_distance);
+        }
+
+        for (size_t c = 0; c < nb; ++c) {
+          wp_states[c].UpdateErrors(best_val[c], x, y, w);
+          p_quant[c][x] = best_val[c];
+        }
+        float len_error = 0;
+        for (size_t c = 0; c < nb; ++c) {
+          float local_error = color_with_error[c] - best_val[c];
+          len_error += local_error * local_error;
+        }
+        len_error = sqrt(len_error);
+        float modulate = 1.0;
+        int len_limit = 38 << std::max(0, bit_depth - 8);
+        if (len_error > len_limit) {
+          modulate *= len_limit / len_error;
+        }
+        for (size_t c = 0; c < nb; ++c) {
+          float total_error = (color_with_error[c] - best_val[c]);
+
+          // If the neighboring pixels have some error in the opposite
+          // direction of total_error, cancel some or all of it out before
+          // spreading among them.
+          // NOTE(review): the table has 12 entries but both loops below stop
+          // at 11, so {2, 4} is never consulted - confirm this is intended.
+          constexpr int offsets[12][2] = {{1, 2}, {0, 3}, {0, 4}, {1, 1},
+                                          {1, 3}, {2, 2}, {1, 0}, {1, 4},
+                                          {2, 1}, {2, 3}, {2, 0}, {2, 4}};
+          float total_available = 0;
+          for (int i = 0; i < 11; ++i) {
+            const int row = offsets[i][0];
+            const int col = offsets[i][1];
+            if (std::signbit(error_row[row][c][x + col]) !=
+                std::signbit(total_error)) {
+              total_available += error_row[row][c][x + col];
+            }
+          }
+          float weight =
+              std::abs(total_error) / (std::abs(total_available) + 1e-3);
+          weight = std::min(weight, 1.0f);
+          for (int i = 0; i < 11; ++i) {
+            const int row = offsets[i][0];
+            const int col = offsets[i][1];
+            if (std::signbit(error_row[row][c][x + col]) !=
+                std::signbit(total_error)) {
+              total_error += weight * error_row[row][c][x + col];
+              error_row[row][c][x + col] *= (1 - weight);
+            }
+          }
+          total_error *= modulate;
+          // The shares handed out below add up to 14 units
+          // (2 + 1 + 1 + 5 + 5).
+          const float remaining_error = (1.0f / 14.) * total_error;
+          error_row[0][c][x + 3] += 2 * remaining_error;
+          error_row[0][c][x + 4] += remaining_error;
+          error_row[1][c][x + 0] += remaining_error;
+          for (int i = 0; i < 5; ++i) {
+            error_row[1][c][x + i] += remaining_error;
+            error_row[2][c][x + i] += remaining_error;
+          }
+        }
+      }
+      if (palette_iteration_data.final_run) p[x] = index;
+    }
+    if (lossy) {
+      // Rotate the error rows: row 1 becomes the current row, row 2 becomes
+      // row 1, and the new row 2 starts out cleared.
+      for (size_t c = 0; c < nb; ++c) {
+        error_row[0][c].swap(error_row[1][c]);
+        error_row[1][c].swap(error_row[2][c]);
+        std::fill(error_row[2][c].begin(), error_row[2][c].end(), 0.f);
+      }
+    }
+  }
+  if (!delta_used) {
+    predictor = Predictor::Zero;
+  }
+  if (palette_iteration_data.final_run) {
+    input.nb_meta_channels++;
+    input.channel.erase(input.channel.begin() + begin_c + 1,
+                        input.channel.begin() + end_c + 1);
+    input.channel.insert(input.channel.begin(), std::move(pch));
+  }
+  nb_colors -= nb_deltas;
+  return true;
+}
+
+// Applies the forward palette transform to channels [begin_c, end_c].
+// For a lossy palette on images of at least 8 bits an analysis pass runs
+// first; it works on copies of nb_colors / nb_deltas so only the final pass
+// updates the caller's values. Returns the status of the final pass (or the
+// error of the analysis pass if that one fails).
+Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c,
+                  uint32_t &nb_colors, uint32_t &nb_deltas, bool ordered,
+                  bool lossy, Predictor &predictor,
+                  const weighted::Header &wp_header) {
+  PaletteIterationData iteration_state;
+  const bool needs_analysis_pass = lossy && input.bitdepth >= 8;
+  if (needs_analysis_pass) {
+    // preprocessing pass in case of lossy palette
+    uint32_t scratch_nb_colors = nb_colors;
+    uint32_t scratch_nb_deltas = nb_deltas;
+    JXL_RETURN_IF_ERROR(FwdPaletteIteration(
+        input, begin_c, end_c, scratch_nb_colors, scratch_nb_deltas, ordered,
+        lossy, predictor, wp_header, iteration_state));
+  }
+  iteration_state.final_run = true;
+  const Status final_status =
+      FwdPaletteIteration(input, begin_c, end_c, nb_colors, nb_deltas, ordered,
+                          lossy, predictor, wp_header, iteration_state);
+  return final_status;
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h
new file mode 100644
index 0000000000..0f3d66825b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h
@@ -0,0 +1,22 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_
+#define LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_
+
+#include "lib/jxl/fields.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/modular_image.h"
+
+namespace jxl {
+
+// Forward palette transform on channels [begin_c, end_c] of `input`.
+// nb_colors, nb_deltas and predictor are taken by reference and may be
+// updated by the call.
+Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c,
+ uint32_t &nb_colors, uint32_t &nb_deltas, bool ordered,
+ bool lossy, Predictor &predictor,
+ const weighted::Header &wp_header);
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc
new file mode 100644
index 0000000000..050563a3c2
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc
@@ -0,0 +1,73 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/transform/enc_rct.h"
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels
+
+namespace jxl {
+
+// Applies the forward reversible color transform in place to the three
+// channels starting at begin_c. rct_type / 7 selects the channel permutation
+// and rct_type % 7 the transform itself; rct_type 0 is a no-op and makes the
+// function return false. Rows are processed through RunOnPool.
+Status FwdRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) {
+  JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, begin_c + 2));
+  if (rct_type == 0) return false;  // noop
+
+  // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR
+  const int permutation = rct_type / 7;
+  // 0-5 values have the low bit corresponding to Third and the high bits
+  // corresponding to Second. 6 corresponds to YCoCg.
+  //
+  // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird
+  //
+  // Third: 0=nop, 1=SubtractFirst
+  const int custom = rct_type % 7;
+  const int second = (custom % 7) >> 1;
+  const int third = (custom % 7) & 1;
+
+  const size_t m = begin_c;
+  const size_t xsize = input.channel[m + 0].w;
+  const size_t ysize = input.channel[m + 0].h;
+
+  // Source channels after applying the permutation.
+  const size_t src0 = m + (permutation % 3);
+  const size_t src1 = m + ((permutation + 1 + permutation / 3) % 3);
+  const size_t src2 = m + ((permutation + 2 - permutation / 3) % 3);
+
+  const auto process_row = [&](const int y, const int thread) {
+    const pixel_type* in0 = input.channel[src0].Row(y);
+    const pixel_type* in1 = input.channel[src1].Row(y);
+    const pixel_type* in2 = input.channel[src2].Row(y);
+    pixel_type* out0 = input.channel[m].Row(y);
+    pixel_type* out1 = input.channel[m + 1].Row(y);
+    pixel_type* out2 = input.channel[m + 2].Row(y);
+    if (custom == 6) {
+      for (size_t x = 0; x < xsize; x++) {
+        // All three inputs are read before any output is written: the output
+        // rows are the same memory as the (permuted) input rows.
+        const pixel_type r = in0[x];
+        const pixel_type g = in1[x];
+        const pixel_type b = in2[x];
+        const pixel_type co = r - b;
+        const pixel_type tmp = b + (co >> 1);
+        const pixel_type cg = g - tmp;
+        out1[x] = co;
+        out2[x] = cg;
+        out0[x] = tmp + (cg >> 1);
+      }
+      return;
+    }
+    for (size_t x = 0; x < xsize; x++) {
+      const pixel_type first = in0[x];
+      pixel_type second_value = in1[x];
+      pixel_type third_value = in2[x];
+      // The average uses the unmodified third value, so this must come
+      // before the third channel is adjusted.
+      if (second == 1) {
+        second_value -= first;
+      } else if (second == 2) {
+        second_value -= ((first + third_value) >> 1);
+      }
+      if (third) third_value -= first;
+      out0[x] = first;
+      out1[x] = second_value;
+      out2[x] = third_value;
+    }
+  };
+  return RunOnPool(pool, 0, ysize, ThreadPool::NoInit, process_row, "FwdRCT");
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h
new file mode 100644
index 0000000000..cb5a193c8d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h
@@ -0,0 +1,17 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_
+#define LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_
+
+#include "lib/jxl/modular/modular_image.h"
+
+namespace jxl {
+
+// Forward reversible color transform on the three channels starting at
+// begin_c; rct_type selects the permutation and the transform variant.
+Status FwdRCT(Image &input, size_t begin_c, size_t rct_type, ThreadPool *pool);
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc
new file mode 100644
index 0000000000..dfd90cde68
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc
@@ -0,0 +1,141 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/transform/enc_squeeze.h"
+
+#include <stdlib.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/squeeze.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+namespace jxl {
+
+// Horizontal squeeze of channel `c`: every pair of horizontally adjacent
+// samples is replaced by its average (kept in channel `c`, which becomes
+// (w + 1) / 2 wide) and a residual (difference minus the SmoothTendency
+// prediction), stored in a new channel inserted at position `rc`. For an odd
+// width the last column is copied unchanged into the averaged channel.
+void FwdHSqueeze(Image &input, int c, int rc) {
+  const Channel &src = input.channel[c];
+
+  JXL_DEBUG_V(4, "Doing horizontal squeeze of channel %i to new channel %i", c,
+              rc);
+
+  // Average of a pair; rounds up when the first sample is the larger one.
+  const auto average = [](pixel_type a, pixel_type b) -> pixel_type {
+    return (a + b + (a > b)) >> 1;
+  };
+
+  Channel low((src.w + 1) / 2, src.h, src.hshift + 1, src.vshift);
+  Channel residual(src.w - low.w, low.h, src.hshift + 1, src.vshift);
+
+  for (size_t y = 0; y < low.h; y++) {
+    const pixel_type *JXL_RESTRICT row_in = src.Row(y);
+    pixel_type *JXL_RESTRICT row_low = low.Row(y);
+    pixel_type *JXL_RESTRICT row_res = residual.Row(y);
+    for (size_t x = 0; x < residual.w; x++) {
+      const pixel_type a = row_in[x * 2];
+      const pixel_type b = row_in[x * 2 + 1];
+      const pixel_type avg = average(a, b);
+      row_low[x] = avg;
+
+      // Average of the next pair (which will be the next sample of `low`),
+      // or the lone trailing sample of an odd-width row.
+      pixel_type next_avg = avg;
+      if (x + 1 < residual.w) {
+        next_avg = average(row_in[x * 2 + 2], row_in[x * 2 + 3]);
+      } else if (src.w & 1) {
+        next_avg = row_in[x * 2 + 2];
+      }
+      const pixel_type left = (x > 0 ? row_in[x * 2 - 1] : avg);
+      const pixel_type tendency = SmoothTendency(left, avg, next_avg);
+
+      row_res[x] = (a - b) - tendency;
+    }
+    if (src.w & 1) {
+      const int last = low.w - 1;
+      row_low[last] = row_in[last * 2];
+    }
+  }
+  input.channel[c] = std::move(low);
+  input.channel.insert(input.channel.begin() + rc, std::move(residual));
+}
+
+// Vertical squeeze of channel `c`: every pair of vertically adjacent samples
+// is replaced by its average (kept in channel `c`, which becomes (h + 1) / 2
+// high) and a residual (difference minus the SmoothTendency prediction),
+// stored in a new channel inserted at position `rc`. For an odd height the
+// last row is copied unchanged into the averaged channel.
+void FwdVSqueeze(Image &input, int c, int rc) {
+  const Channel &src = input.channel[c];
+
+  JXL_DEBUG_V(4, "Doing vertical squeeze of channel %i to new channel %i", c,
+              rc);
+
+  // Average of a pair; rounds up when the first sample is the larger one.
+  const auto average = [](pixel_type a, pixel_type b) -> pixel_type {
+    return (a + b + (a > b)) >> 1;
+  };
+
+  Channel low(src.w, (src.h + 1) / 2, src.hshift, src.vshift + 1);
+  Channel residual(src.w, src.h - low.h, src.hshift, src.vshift + 1);
+  // Distance, in pixels, between vertically adjacent samples of the source.
+  const intptr_t stride = src.plane.PixelsPerRow();
+  for (size_t y = 0; y < residual.h; y++) {
+    const pixel_type *JXL_RESTRICT row_in = src.Row(y * 2);
+    pixel_type *JXL_RESTRICT row_low = low.Row(y);
+    pixel_type *JXL_RESTRICT row_res = residual.Row(y);
+    const bool has_next_pair = y + 1 < residual.h;
+    for (size_t x = 0; x < low.w; x++) {
+      const pixel_type a = row_in[x];
+      const pixel_type b = row_in[x + stride];
+      const pixel_type avg = average(a, b);
+      row_low[x] = avg;
+
+      // Average of the pair below (which will be the next row of `low`), or
+      // the lone trailing row of an odd-height channel.
+      pixel_type next_avg = avg;
+      if (has_next_pair) {
+        next_avg = average(row_in[x + 2 * stride], row_in[x + 3 * stride]);
+      } else if (src.h & 1) {
+        next_avg = row_in[x + 2 * stride];
+      }
+      const pixel_type top =
+          (y > 0 ? row_in[static_cast<ssize_t>(x) - stride] : avg);
+      const pixel_type tendency = SmoothTendency(top, avg, next_avg);
+
+      row_res[x] = (a - b) - tendency;
+    }
+  }
+  if (src.h & 1) {
+    const size_t last = low.h - 1;
+    const pixel_type *row_in = src.Row(last * 2);
+    pixel_type *row_low = low.Row(last);
+    for (size_t x = 0; x < low.w; x++) {
+      row_low[x] = row_in[x];
+    }
+  }
+  input.channel[c] = std::move(low);
+  input.channel.insert(input.channel.begin() + rc, std::move(residual));
+}
+
+// Runs the forward squeeze steps listed in `parameters` on `input`; an empty
+// list is replaced by DefaultSqueezeParameters(). Returns false when there is
+// still nothing to do, and fails when a step does not pass
+// CheckMetaSqueezeParams. `pool` is not used by this function.
+Status FwdSqueeze(Image &input, std::vector<SqueezeParams> parameters,
+                  ThreadPool *pool) {
+  if (parameters.empty()) {
+    DefaultSqueezeParameters(&parameters, input);
+  }
+  // if nothing to do, don't do squeeze
+  if (parameters.empty()) return false;
+
+  for (auto &step : parameters) {
+    JXL_RETURN_IF_ERROR(CheckMetaSqueezeParams(step, input.channel.size()));
+    const bool horizontal = step.horizontal;
+    const uint32_t first_c = step.begin_c;
+    const uint32_t last_c = step.begin_c + step.num_c - 1;
+    // Position of the first residual channel: right after the squeezed range
+    // when in place, otherwise at the end of the channel list as it is before
+    // this step starts.
+    uint32_t residual_base;
+    if (step.in_place) {
+      residual_base = last_c + 1;
+    } else {
+      residual_base = input.channel.size();
+    }
+    for (uint32_t c = first_c; c <= last_c; c++) {
+      if (horizontal) {
+        FwdHSqueeze(input, c, residual_base + c - first_c);
+      } else {
+        FwdVSqueeze(input, c, residual_base + c - first_c);
+      }
+    }
+  }
+  return true;
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h
new file mode 100644
index 0000000000..39b001017b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h
@@ -0,0 +1,20 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_
+#define LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_
+
+#include "lib/jxl/fields.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+namespace jxl {
+
+Status FwdSqueeze(Image &input, std::vector<SqueezeParams> parameters,  // forward squeeze of `input`, one step per entry of `parameters`
+ ThreadPool *pool);  // NOTE(review): the definition appears to fall back to default parameters when the list is empty - confirm in enc_squeeze.cc
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc
new file mode 100644
index 0000000000..bdaaf9f87e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc
@@ -0,0 +1,46 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/transform/enc_transform.h"
+
+#include "lib/jxl/modular/transform/enc_palette.h"
+#include "lib/jxl/modular/transform/enc_rct.h"
+#include "lib/jxl/modular/transform/enc_squeeze.h"
+
+namespace jxl {
+
+// Applies the forward (encoder-side) version of transform `t` to `input` by
+// delegating to FwdRCT, FwdSqueeze or FwdPalette according to `t.id`.
+// Any other id is reported as a failure.
+Status TransformForward(Transform &t, Image &input,
+                        const weighted::Header &wp_header, ThreadPool *pool) {
+  if (t.id == TransformId::kRCT) {
+    return FwdRCT(input, t.begin_c, t.rct_type, pool);
+  }
+  if (t.id == TransformId::kSqueeze) {
+    return FwdSqueeze(input, t.squeezes, pool);
+  }
+  if (t.id == TransformId::kPalette) {
+    // FwdPalette takes an inclusive channel range.
+    const auto last_c = t.begin_c + t.num_c - 1;
+    return FwdPalette(input, t.begin_c, last_c, t.nb_colors, t.nb_deltas,
+                      t.ordered_palette, t.lossy_palette, t.predictor,
+                      wp_header);
+  }
+  return JXL_FAILURE("Unknown transformation (ID=%u)",
+                     static_cast<unsigned int>(t.id));
+}
+
+// Scans every sample of `ch` and stores the smallest / largest value through
+// `min` / `max`; either pointer may be null to skip that output. For a
+// channel without samples the outputs are the initial sentinels
+// (numeric_limits max for *min, numeric_limits min for *max).
+void compute_minmax(const Channel &ch, pixel_type *min, pixel_type *max) {
+  pixel_type lowest = std::numeric_limits<pixel_type>::max();
+  pixel_type highest = std::numeric_limits<pixel_type>::min();
+  for (size_t row = 0; row < ch.h; ++row) {
+    const pixel_type *JXL_RESTRICT samples = ch.Row(row);
+    for (size_t col = 0; col < ch.w; ++col) {
+      const pixel_type value = samples[col];
+      lowest = (value < lowest) ? value : lowest;
+      highest = (value > highest) ? value : highest;
+    }
+  }
+
+  if (min != nullptr) *min = lowest;
+  if (max != nullptr) *max = highest;
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h
new file mode 100644
index 0000000000..07659e1b0a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h
@@ -0,0 +1,22 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_
+#define LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_
+
+#include "lib/jxl/fields.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+namespace jxl {
+
+Status TransformForward(Transform &t, Image &input,  // applies the forward (encoder-side) version of transform `t` to `input`
+ const weighted::Header &wp_header, ThreadPool *pool);
+
+void compute_minmax(const Channel &ch, pixel_type *min, pixel_type *max);  // NOTE(review): min/max appear to be optional (nullable) outputs - confirm in enc_transform.cc
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc
new file mode 100644
index 0000000000..46129f19f0
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc
@@ -0,0 +1,176 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/transform/palette.h"
+
+namespace jxl {
+
+// Undoes the palette transform on `input`, in place.
+//
+// Channel 0 must be the palette meta-channel: its height is the number of
+// channels every index expands to and its width the number of explicit
+// palette entries. Channel `begin_c + 1` holds the per-pixel indices; it
+// becomes the first output channel and the remaining output channels are
+// inserted right after it. The palette channel is erased at the end and
+// `nb_meta_channels` is adjusted accordingly.
+//
+// Indices below `nb_deltas` denote delta entries: their palette value is
+// added to the prediction made by `predictor` (`wp_header` configures the
+// weighted predictor). `nb_colors` is not used here.
+//
+// Returns a failure if there is no meta-channel, if `begin_c + 1` is not a
+// valid channel, if the palette has no rows, or if a thread-pool run fails.
+Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors,
+                  uint32_t nb_deltas, Predictor predictor,
+                  const weighted::Header &wp_header, ThreadPool *pool) {
+  if (input.nb_meta_channels < 1) {
+    return JXL_FAILURE("Error: Palette transform without palette.");
+  }
+  int nb = input.channel[0].h;
+  uint32_t c0 = begin_c + 1;
+  if (c0 >= input.channel.size()) {
+    return JXL_FAILURE("Channel is out of range.");
+  }
+  size_t w = input.channel[c0].w;
+  size_t h = input.channel[c0].h;
+  if (nb < 1) return JXL_FAILURE("Corrupted transforms");
+  // Make room for the nb - 1 additional output channels next to the index
+  // channel; they share its dimensions and shifts.
+  for (int i = 1; i < nb; i++) {
+    input.channel.insert(
+        input.channel.begin() + c0 + 1,
+        Channel(w, h, input.channel[c0].hshift, input.channel[c0].vshift));
+  }
+  // Only take references/pointers into `input.channel` after the inserts
+  // above, which may reallocate the vector.
+  const Channel &palette = input.channel[0];
+  const pixel_type *JXL_RESTRICT p_palette = input.channel[0].Row(0);
+  intptr_t onerow = input.channel[0].plane.PixelsPerRow();
+  intptr_t onerow_image = input.channel[c0].plane.PixelsPerRow();
+  const int bit_depth = std::min(input.bitdepth, 24);
+
+  if (w == 0) {
+    // Nothing to do.
+    // Avoid touching "empty" channels with non-zero height.
+  } else if (nb_deltas == 0 && predictor == Predictor::Zero) {
+    // No delta entries and no prediction: a plain lookup, parallel per row.
+    if (nb == 1) {
+      JXL_RETURN_IF_ERROR(RunOnPool(
+          pool, 0, h, ThreadPool::NoInit,
+          [&](const uint32_t task, size_t /* thread */) {
+            const size_t y = task;
+            pixel_type *p = input.channel[c0].Row(y);
+            for (size_t x = 0; x < w; x++) {
+              // Single-channel palettes clamp the index to the explicit
+              // entries.
+              const int index = Clamp1<int>(
+                  p[x], 0, static_cast<pixel_type>(palette.w) - 1);
+              p[x] = palette_internal::GetPaletteValue(
+                  p_palette, index, /*c=*/0,
+                  /*palette_size=*/palette.w,
+                  /*onerow=*/onerow, /*bit_depth=*/bit_depth);
+            }
+          },
+          "UndoChannelPalette"));
+    } else {
+      JXL_RETURN_IF_ERROR(RunOnPool(
+          pool, 0, h, ThreadPool::NoInit,
+          [&](const uint32_t task, size_t /* thread */) {
+            const size_t y = task;
+            std::vector<pixel_type *> p_out(nb);
+            const pixel_type *p_index = input.channel[c0].Row(y);
+            for (int c = 0; c < nb; c++)
+              p_out[c] = input.channel[c0 + c].Row(y);
+            for (size_t x = 0; x < w; x++) {
+              // The index row aliases p_out[0], so read the index before any
+              // output of this pixel is written.
+              const int index = p_index[x];
+              for (int c = 0; c < nb; c++) {
+                p_out[c][x] = palette_internal::GetPaletteValue(
+                    p_palette, index, /*c=*/c,
+                    /*palette_size=*/palette.w,
+                    /*onerow=*/onerow, /*bit_depth=*/bit_depth);
+              }
+            }
+          },
+          "UndoPalette"));
+    }
+  } else {
+    // Parallelized per channel.
+    // Prediction reads already reconstructed neighbours of the same channel,
+    // so the indices are copied out before channel c0 is overwritten.
+    ImageI indices = CopyImage(input.channel[c0].plane);
+    if (predictor == Predictor::Weighted) {
+      JXL_RETURN_IF_ERROR(RunOnPool(
+          pool, 0, nb, ThreadPool::NoInit,
+          [&](const uint32_t c, size_t /* thread */) {
+            Channel &channel = input.channel[c0 + c];
+            weighted::State wp_state(wp_header, channel.w, channel.h);
+            for (size_t y = 0; y < channel.h; y++) {
+              pixel_type *JXL_RESTRICT p = channel.Row(y);
+              const pixel_type *JXL_RESTRICT idx = indices.Row(y);
+              for (size_t x = 0; x < channel.w; x++) {
+                int index = idx[x];
+                pixel_type_w val = 0;
+                const pixel_type palette_entry =
+                    palette_internal::GetPaletteValue(
+                        p_palette, index, /*c=*/c,
+                        /*palette_size=*/palette.w, /*onerow=*/onerow,
+                        /*bit_depth=*/bit_depth);
+                if (index < static_cast<int32_t>(nb_deltas)) {
+                  PredictionResult pred =
+                      PredictNoTreeWP(channel.w, p + x, onerow_image, x, y,
+                                      predictor, &wp_state);
+                  val = pred.guess + palette_entry;
+                } else {
+                  val = palette_entry;
+                }
+                p[x] = val;
+                // The weighted predictor state is updated for every pixel,
+                // delta entry or not.
+                wp_state.UpdateErrors(p[x], x, y, channel.w);
+              }
+            }
+          },
+          "UndoDeltaPaletteWP"));
+    } else {
+      JXL_RETURN_IF_ERROR(RunOnPool(
+          pool, 0, nb, ThreadPool::NoInit,
+          [&](const uint32_t c, size_t /* thread */) {
+            Channel &channel = input.channel[c0 + c];
+            for (size_t y = 0; y < channel.h; y++) {
+              pixel_type *JXL_RESTRICT p = channel.Row(y);
+              const pixel_type *JXL_RESTRICT idx = indices.Row(y);
+              for (size_t x = 0; x < channel.w; x++) {
+                int index = idx[x];
+                pixel_type_w val = 0;
+                const pixel_type palette_entry =
+                    palette_internal::GetPaletteValue(
+                        p_palette, index, /*c=*/c,
+                        /*palette_size=*/palette.w,
+                        /*onerow=*/onerow, /*bit_depth=*/bit_depth);
+                if (index < static_cast<int32_t>(nb_deltas)) {
+                  PredictionResult pred = PredictNoTreeNoWP(
+                      channel.w, p + x, onerow_image, x, y, predictor);
+                  val = pred.guess + palette_entry;
+                } else {
+                  val = palette_entry;
+                }
+                p[x] = val;
+              }
+            }
+          },
+          "UndoDeltaPaletteNoWP"));
+    }
+  }
+  if (c0 >= input.nb_meta_channels) {
+    // Palette was done on normal channels
+    input.nb_meta_channels--;
+  } else {
+    // Palette was done on metachannels
+    JXL_ASSERT(static_cast<int>(input.nb_meta_channels) >= 2 - nb);
+    input.nb_meta_channels -= 2 - nb;
+    JXL_ASSERT(begin_c + nb - 1 < input.nb_meta_channels);
+  }
+  // Drop the palette meta-channel itself.
+  input.channel.erase(input.channel.begin(), input.channel.begin() + 1);
+  // All failure paths above return early, so reaching this point is success.
+  return true;
+}
+
+// Updates only the channel layout of `input` for a palette transform over
+// channels begin_c..end_c (inclusive): all but the first channel of the range
+// are removed and a palette channel of size (nb_colors + nb_deltas) x
+// (number of channels in the range) is inserted at the front with both
+// shifts set to -1. `lossy` is not used here.
+Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c,
+                   uint32_t nb_colors, uint32_t nb_deltas, bool lossy) {
+  JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c));
+
+  const size_t nb = end_c - begin_c + 1;
+  const bool on_meta_channels = begin_c < input.nb_meta_channels;
+  if (on_meta_channels) {
+    // The whole range has to lie within the meta-channels.
+    JXL_ASSERT(end_c < input.nb_meta_channels);
+    // nb - 1 meta-channels disappear and one (the palette) is added.
+    input.nb_meta_channels += 2 - nb;
+  } else {
+    // Regular channels were palettized: only the palette is added.
+    input.nb_meta_channels++;
+  }
+
+  const auto erase_from = input.channel.begin() + begin_c + 1;
+  const auto erase_to = input.channel.begin() + end_c + 1;
+  input.channel.erase(erase_from, erase_to);
+
+  Channel palette_channel(nb_colors + nb_deltas, nb);
+  palette_channel.hshift = -1;
+  palette_channel.vshift = -1;
+  input.channel.insert(input.channel.begin(), std::move(palette_channel));
+  return true;
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h
new file mode 100644
index 0000000000..cc0f67960b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h
@@ -0,0 +1,129 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_
+#define LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_
+
+#include <atomic>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels
+
+namespace jxl {
+
+namespace palette_internal {
+
+static constexpr int kMaxPaletteLookupTableSize = 1 << 16;  // 65536; not referenced in this header - NOTE(review): presumably an encoder-side limit, confirm at the use site
+
+static constexpr int kRgbChannels = 3;  // implicit palette entries are defined for the first 3 channels only
+
+// 5x5x5 color cube for the larger cube.
+static constexpr int kLargeCube = 5;
+
+// Smaller interleaved color cube to fill the holes of the larger cube.
+static constexpr int kSmallCube = 4;
+static constexpr int kSmallCubeBits = 2;  // bits per channel in a small-cube index (4 == 1 << 2)
+// kSmallCube ** 3 == 64: number of small-cube entries, i.e. the offset at which the large cube starts.
+static constexpr int kLargeCubeOffset = kSmallCube * kSmallCube * kSmallCube;
+
+// Maps `value` (a cube coordinate) onto the sample range of `bit_depth` bits:
+// value * (2^bit_depth - 1) / denom.
+static inline pixel_type Scale(uint64_t value, uint64_t bit_depth,
+                               uint64_t denom) {
+  // Every caller passes kSmallCube or kLargeCube - 1, which are both 4, so
+  // the division by `denom` is done with a shift by two.
+  JXL_ASSERT(denom == 4);
+  const uint64_t max_sample = (static_cast<uint64_t>(1) << bit_depth) - 1;
+  const uint64_t scaled = value * max_sample;
+  return scaled >> 2;
+}
+
+// Returns channel `c` of palette entry `index`. Indices in [0, palette_size) read the explicit palette (row stride `onerow`);
+// all other indices are extended to implicit values. If index < nb_deltas, indicating that the
+// result is a delta palette entry, it is the responsibility of the caller to
+// treat it as such.
+static JXL_MAYBE_UNUSED pixel_type
+GetPaletteValue(const pixel_type *const palette, int index, const size_t c,
+ const int palette_size, const int onerow, const int bit_depth) {
+ if (index < 0) {  // negative indices select an implicit entry built from the table below
+ static constexpr std::array<std::array<pixel_type, 3>, 72> kDeltaPalette = {
+ {
+ {{0, 0, 0}}, {{4, 4, 4}}, {{11, 0, 0}},
+ {{0, 0, -13}}, {{0, -12, 0}}, {{-10, -10, -10}},
+ {{-18, -18, -18}}, {{-27, -27, -27}}, {{-18, -18, 0}},
+ {{0, 0, -32}}, {{-32, 0, 0}}, {{-37, -37, -37}},
+ {{0, -32, -32}}, {{24, 24, 45}}, {{50, 50, 50}},
+ {{-45, -24, -24}}, {{-24, -45, -45}}, {{0, -24, -24}},
+ {{-34, -34, 0}}, {{-24, 0, -24}}, {{-45, -45, -24}},
+ {{64, 64, 64}}, {{-32, 0, -32}}, {{0, -32, 0}},
+ {{-32, 0, 32}}, {{-24, -45, -24}}, {{45, 24, 45}},
+ {{24, -24, -45}}, {{-45, -24, 24}}, {{80, 80, 80}},
+ {{64, 0, 0}}, {{0, 0, -64}}, {{0, -64, -64}},
+ {{-24, -24, 45}}, {{96, 96, 96}}, {{64, 64, 0}},
+ {{45, -24, -24}}, {{34, -34, 0}}, {{112, 112, 112}},
+ {{24, -45, -45}}, {{45, 45, -24}}, {{0, -32, 32}},
+ {{24, -24, 45}}, {{0, 96, 96}}, {{45, -24, 24}},
+ {{24, -45, -24}}, {{-24, -45, 24}}, {{0, -64, 0}},
+ {{96, 0, 0}}, {{128, 128, 128}}, {{64, 0, 64}},
+ {{144, 144, 144}}, {{96, 96, 0}}, {{-36, -36, 36}},
+ {{45, -24, -45}}, {{45, -45, -24}}, {{0, 0, -96}},
+ {{0, 128, 128}}, {{0, 96, 0}}, {{45, 24, -45}},
+ {{-128, 0, 0}}, {{24, -45, 24}}, {{-45, 24, -45}},
+ {{64, 0, -64}}, {{64, -64, -64}}, {{96, 0, 96}},
+ {{45, -45, 24}}, {{24, 45, -45}}, {{64, 64, -64}},
+ {{128, 128, 0}}, {{0, 0, -128}}, {{-24, 45, -45}},
+ }};
+ if (c >= kRgbChannels) {  // the table only has three components
+ return 0;
+ }
+ // Do not open the brackets, otherwise INT32_MIN negation could overflow.
+ index = -(index + 1);  // maps -1, -2, -3, ... to 0, 1, 2, ...
+ index %= 1 + 2 * (kDeltaPalette.size() - 1);  // 143 distinct values: entry 0 plus both signs of entries 1..71
+ static constexpr int kMultiplier[] = {-1, 1};  // even index -> negated entry, odd index -> entry as is
+ pixel_type result =
+ kDeltaPalette[((index + 1) >> 1)][c] * kMultiplier[index & 1];
+ if (bit_depth > 8) {  // table values are multiplied by 2^(bit_depth - 8) for deeper samples
+ result *= static_cast<pixel_type>(1) << (bit_depth - 8);
+ }
+ return result;
+ } else if (palette_size <= index && index < palette_size + kLargeCubeOffset) {  // the 64 indices right after the explicit palette: 4x4x4 cube
+ if (c >= kRgbChannels) return 0;
+ index -= palette_size;
+ index >>= c * kSmallCubeBits;  // two bits per channel, channel 0 in the lowest bits
+ return Scale(index % kSmallCube, bit_depth, kSmallCube) +
+ (1 << (std::max(0, bit_depth - 3)));  // constant offset added on top of the scaled cube coordinate
+ } else if (palette_size + kLargeCubeOffset <= index) {  // everything beyond the small cube: 5x5x5 cube
+ if (c >= kRgbChannels) return 0;
+ index -= palette_size + kLargeCubeOffset;
+ // TODO(eustas): should we take care of ambiguity created by
+ // index >= kLargeCube ** 3 ?
+ switch (c) {  // base-5 digit selection: channel c uses digit c of the index
+ case 0:
+ break;
+ case 1:
+ index /= kLargeCube;
+ break;
+ case 2:
+ index /= kLargeCube * kLargeCube;
+ break;
+ }
+ return Scale(index % kLargeCube, bit_depth, kLargeCube - 1);
+ }
+ return palette[c * onerow + static_cast<size_t>(index)];  // explicit entry: row c, column index
+}
+
+} // namespace palette_internal
+
+Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors,  // inverse palette transform; NOTE(review): the palette is expected in channel 0 - confirm in palette.cc
+ uint32_t nb_deltas, Predictor predictor,
+ const weighted::Header &wp_header, ThreadPool *pool);
+
+Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c,  // channel-layout counterpart of the palette transform; end_c is inclusive
+ uint32_t nb_colors, uint32_t nb_deltas, bool lossy);
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc
new file mode 100644
index 0000000000..f3002a5ac3
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc
@@ -0,0 +1,153 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/transform/rct.h"
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/rct.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Add;
+using hwy::HWY_NAMESPACE::ShiftRight;
+using hwy::HWY_NAMESPACE::Sub;
+
+template <int transform_type>  // inverts one row of `w` pixels; transform_type is the arithmetic variant 0..6 (6 takes the Y/Co/Cg path)
+void InvRCTRow(const pixel_type* in0, const pixel_type* in1,
+ const pixel_type* in2, pixel_type* out0, pixel_type* out1,
+ pixel_type* out2, size_t w) {
+ static_assert(transform_type >= 0 && transform_type < 7,
+ "Invalid transform type");
+ int second = transform_type >> 1;  // how Second is restored: 0 = unchanged, 1 = add First, 2 = add average of First and Third
+ int third = transform_type & 1;  // whether First is added back to Third
+
+ size_t x = 0;
+ const HWY_FULL(pixel_type) d;
+ const size_t N = Lanes(d);
+ for (; x + N - 1 < w; x += N) {  // vector loop: only whole groups of N pixels
+ if (transform_type == 6) {
+ auto Y = Load(d, in0 + x);
+ auto Co = Load(d, in1 + x);
+ auto Cg = Load(d, in2 + x);
+ Y = Sub(Y, ShiftRight<1>(Cg));  // same steps as the scalar tail below: tmp = Y - (Cg >> 1)
+ auto G = Add(Cg, Y);
+ Y = Sub(Y, ShiftRight<1>(Co));  // Y now holds what the scalar tail calls B
+ auto R = Add(Y, Co);
+ Store(R, d, out0 + x);
+ Store(G, d, out1 + x);
+ Store(Y, d, out2 + x);
+ } else {
+ auto First = Load(d, in0 + x);
+ auto Second = Load(d, in1 + x);
+ auto Third = Load(d, in2 + x);
+ if (third) Third = Add(Third, First);  // Third is restored first because Second may depend on it
+ if (second == 1) {
+ Second = Add(Second, First);
+ } else if (second == 2) {
+ Second = Add(Second, ShiftRight<1>(Add(First, Third)));
+ }
+ Store(First, d, out0 + x);
+ Store(Second, d, out1 + x);
+ Store(Third, d, out2 + x);
+ }
+ }
+ for (; x < w; x++) {  // scalar tail for the remaining pixels, same arithmetic via PixelAdd
+ if (transform_type == 6) {
+ pixel_type Y = in0[x];
+ pixel_type Co = in1[x];
+ pixel_type Cg = in2[x];
+ pixel_type tmp = PixelAdd(Y, -(Cg >> 1));
+ pixel_type G = PixelAdd(Cg, tmp);
+ pixel_type B = PixelAdd(tmp, -(Co >> 1));
+ pixel_type R = PixelAdd(B, Co);
+ out0[x] = R;
+ out1[x] = G;
+ out2[x] = B;
+ } else {
+ pixel_type First = in0[x];
+ pixel_type Second = in1[x];
+ pixel_type Third = in2[x];
+ if (third) Third = PixelAdd(Third, First);
+ if (second == 1) {
+ Second = PixelAdd(Second, First);
+ } else if (second == 2) {
+ Second = PixelAdd(Second, (PixelAdd(First, Third) >> 1));
+ }
+ out0[x] = First;
+ out1[x] = Second;
+ out2[x] = Third;
+ }
+ }
+}
+
+// Inverse reversible colour transform on channels begin_c .. begin_c + 2.
+//
+// rct_type / 7 selects the output permutation
+// (0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR); rct_type % 7 selects the
+// arithmetic: for 0-5 the low bit describes Third (0=nop, 1=SubtractFirst)
+// and the high bits describe Second (0=nop, 1=SubtractFirst,
+// 2=SubtractAvgFirstThird); 6 corresponds to YCoCg.
+Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) {
+  JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, begin_c + 2));
+  // Type 0 is the identity: no permutation and no arithmetic.
+  if (rct_type == 0) return true;
+
+  const int permutation = rct_type / 7;
+  JXL_CHECK(permutation < 6);
+  const int custom = rct_type % 7;
+
+  // Destination slots of the first, second and third decoded channel.
+  const size_t base = begin_c;
+  const size_t dst0 = base + (permutation % 3);
+  const size_t dst1 = base + ((permutation + 1 + permutation / 3) % 3);
+  const size_t dst2 = base + ((permutation + 2 - permutation / 3) % 3);
+
+  if (custom == 0) {
+    // Permutation only: shuffle the channel objects, no pixel is touched.
+    Channel first = std::move(input.channel[base]);
+    Channel second = std::move(input.channel[base + 1]);
+    Channel third = std::move(input.channel[base + 2]);
+    input.channel[dst0] = std::move(first);
+    input.channel[dst1] = std::move(second);
+    input.channel[dst2] = std::move(third);
+    return true;
+  }
+
+  // One row kernel per arithmetic variant, indexed by `custom`.
+  constexpr decltype(&InvRCTRow<0>) kRowKernels[] = {
+      InvRCTRow<0>, InvRCTRow<1>, InvRCTRow<2>, InvRCTRow<3>,
+      InvRCTRow<4>, InvRCTRow<5>, InvRCTRow<6>};
+  const auto row_kernel = kRowKernels[custom];
+  const size_t w = input.channel[base].w;
+  const size_t h = input.channel[base].h;
+
+  // Rows are independent, so every row is one task.
+  const auto process_row = [&](const uint32_t task, size_t /* thread */) {
+    const size_t y = task;
+    const pixel_type* src0 = input.channel[base].Row(y);
+    const pixel_type* src1 = input.channel[base + 1].Row(y);
+    const pixel_type* src2 = input.channel[base + 2].Row(y);
+    pixel_type* dst_row0 = input.channel[dst0].Row(y);
+    pixel_type* dst_row1 = input.channel[dst1].Row(y);
+    pixel_type* dst_row2 = input.channel[dst2].Row(y);
+    row_kernel(src0, src1, src2, dst_row0, dst_row1, dst_row2, w);
+  };
+  JXL_RETURN_IF_ERROR(
+      RunOnPool(pool, 0, h, ThreadPool::NoInit, process_row, "InvRCT"));
+  return true;
+}
+
+} // namespace HWY_NAMESPACE
+} // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+HWY_EXPORT(InvRCT);  // registers the per-target InvRCT variants compiled above for dynamic dispatch
+Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) {  // public entry point, compiled once (HWY_ONCE)
+ return HWY_DYNAMIC_DISPATCH(InvRCT)(input, begin_c, rct_type, pool);  // forwards unchanged to the variant selected by Highway
+}
+
+} // namespace jxl
+#endif
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/rct.h b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.h
new file mode 100644
index 0000000000..aef65621d5
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.h
@@ -0,0 +1,20 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_RCT_H_
+#define LIB_JXL_MODULAR_TRANSFORM_RCT_H_
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels
+
+namespace jxl {
+
+Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool);  // inverse RCT; NOTE(review): appears to act on the three channels starting at begin_c - confirm in rct.cc
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_TRANSFORM_RCT_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc
new file mode 100644
index 0000000000..8440d9e804
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc
@@ -0,0 +1,478 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/transform/squeeze.h"
+
+#include <stdlib.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/printf_macros.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/transform.h"
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/squeeze.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/simd_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Abs;
+using hwy::HWY_NAMESPACE::Add;
+using hwy::HWY_NAMESPACE::And;
+using hwy::HWY_NAMESPACE::Gt;
+using hwy::HWY_NAMESPACE::IfThenElse;
+using hwy::HWY_NAMESPACE::IfThenZeroElse;
+using hwy::HWY_NAMESPACE::Lt;
+using hwy::HWY_NAMESPACE::MulEven;
+using hwy::HWY_NAMESPACE::Ne;
+using hwy::HWY_NAMESPACE::Neg;
+using hwy::HWY_NAMESPACE::OddEven;
+using hwy::HWY_NAMESPACE::RebindToUnsigned;
+using hwy::HWY_NAMESPACE::ShiftLeft;
+using hwy::HWY_NAMESPACE::ShiftRight;
+using hwy::HWY_NAMESPACE::Sub;
+using hwy::HWY_NAMESPACE::Xor;
+
+#if HWY_TARGET != HWY_SCALAR
+
+JXL_INLINE void FastUnsqueeze(const pixel_type *JXL_RESTRICT p_residual,  // SIMD unsqueeze of exactly 8 consecutive samples
+ const pixel_type *JXL_RESTRICT p_avg,  // averages at the current position
+ const pixel_type *JXL_RESTRICT p_navg,  // averages at the next position
+ const pixel_type *p_pout,  // previously reconstructed samples ("top"); not restrict-qualified, callers may pass p_avg here
+ pixel_type *JXL_RESTRICT p_out,  // receives the first sample of each reconstructed pair
+ pixel_type *p_nout) {  // receives the second sample of each reconstructed pair
+ const HWY_CAPPED(pixel_type, 8) d;  // vectors of at most 8 lanes
+ const RebindToUnsigned<decltype(d)> du;
+ const size_t N = Lanes(d);
+ auto onethird = Set(d, 0x55555556);  // 0x55555556 == ceil(2^32 / 3); used below to divide by 3 with a multiply and a shift
+ for (size_t x = 0; x < 8; x += N) {  // always 8 samples, in as many steps as the lane count requires
+ auto avg = Load(d, p_avg + x);
+ auto next_avg = Load(d, p_navg + x);
+ auto top = Load(d, p_pout + x);
+ // Equivalent to SmoothTendency(top,avg,next_avg), but without branches
+ auto Ba = Sub(top, avg);
+ auto an = Sub(avg, next_avg);
+ auto nonmono = Xor(Ba, an);  // sign bit set when Ba and an have different signs
+ auto absBa = Abs(Ba);
+ auto absan = Abs(an);
+ auto absBn = Abs(Sub(top, next_avg));
+ // Compute a3 = absBa / 3
+ auto a3e = BitCast(d, ShiftRight<32>(MulEven(absBa, onethird)));  // quotients of the even lanes
+ auto a3oi = MulEven(Reverse(d, absBa), onethird);  // lane order reversed so the odd lanes take part in MulEven
+ auto a3o = BitCast(
+ d, Reverse(hwy::HWY_NAMESPACE::Repartition<pixel_type_w, decltype(d)>(),
+ a3oi));  // wide products reversed back to the original order
+ auto a3 = OddEven(a3o, a3e);  // odd lanes from a3o, even lanes from a3e
+ a3 = Add(a3, Add(absBn, Set(d, 2)));
+ auto absdiff = ShiftRight<2>(a3);  // (absBa / 3 + absBn + 2) >> 2
+ auto skipdiff = Ne(Ba, Zero(d));
+ skipdiff = And(skipdiff, Ne(an, Zero(d)));
+ skipdiff = And(skipdiff, Lt(nonmono, Zero(d)));  // mask: both differences non-zero and of opposite sign
+ auto absBa2 = Add(ShiftLeft<1>(absBa), And(absdiff, Set(d, 1)));
+ absdiff = IfThenElse(Gt(absdiff, absBa2),
+ Add(ShiftLeft<1>(absBa), Set(d, 1)), absdiff);  // limit the magnitude by 2 * |top - avg| + 1
+ auto absan2 = ShiftLeft<1>(absan);
+ absdiff = IfThenElse(Gt(Add(absdiff, And(absdiff, Set(d, 1))), absan2),
+ absan2, absdiff);  // limit the magnitude by 2 * |avg - next_avg|
+ auto diff1 = IfThenElse(Lt(top, next_avg), Neg(absdiff), absdiff);  // apply the sign of top - next_avg
+ auto tendency = IfThenZeroElse(skipdiff, diff1);  // zero tendency where the skipdiff mask is set
+
+ auto diff_minus_tendency = Load(d, p_residual + x);
+ auto diff = Add(diff_minus_tendency, tendency);
+ auto out =
+ Add(avg, ShiftRight<1>(
+ Add(diff, BitCast(d, ShiftRight<31>(BitCast(du, diff))))));  // adding the sign bit before the shift gives diff / 2 rounded toward zero
+ Store(out, d, p_out + x);
+ Store(Sub(out, diff), d, p_nout + x);
+ }
+}
+
+#endif
+
+Status InvHSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) {  // undoes a horizontal squeeze: channel c (averages) + channel rc (residuals) -> wider channel c
+ JXL_ASSERT(c < input.channel.size());
+ JXL_ASSERT(rc < input.channel.size());
+ Channel &chin = input.channel[c];
+ const Channel &chin_residual = input.channel[rc];
+ // These must be valid since we ran MetaApply already.
+ JXL_ASSERT(chin.w == DivCeil(chin.w + chin_residual.w, 2));
+ JXL_ASSERT(chin.h == chin_residual.h);
+
+ if (chin_residual.w == 0) {
+ // Short-circuit: output channel has same dimensions as input.
+ input.channel[c].hshift--;
+ return true;
+ }
+
+ // Note: chin.w >= chin_residual.w and at most 1 different.
+ Channel chout(chin.w + chin_residual.w, chin.h, chin.hshift - 1, chin.vshift);
+ JXL_DEBUG_V(4,
+ "Undoing horizontal squeeze of channel %i using residuals in "
+ "channel %i (going from width %" PRIuS " to %" PRIuS ")",
+ c, rc, chin.w, chout.w);
+
+ if (chin_residual.h == 0) {
+ // Short-circuit: channel with no pixels.
+ input.channel[c] = std::move(chout);
+ return true;
+ }
+ auto unsqueeze_row = [&](size_t y, size_t x0) {  // scalar reconstruction of row y, starting at input column x0
+ const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y);
+ const pixel_type *JXL_RESTRICT p_avg = chin.Row(y);
+ pixel_type *JXL_RESTRICT p_out = chout.Row(y);
+ for (size_t x = x0; x < chin_residual.w; x++) {
+ pixel_type_w diff_minus_tendency = p_residual[x];
+ pixel_type_w avg = p_avg[x];
+ pixel_type_w next_avg = (x + 1 < chin.w ? p_avg[x + 1] : avg);  // no right neighbour: reuse the current average
+ pixel_type_w left = (x ? p_out[(x << 1) - 1] : avg);  // last reconstructed sample, or the average at the row start
+ pixel_type_w tendency = SmoothTendency(left, avg, next_avg);
+ pixel_type_w diff = diff_minus_tendency + tendency;
+ pixel_type_w A = avg + (diff / 2);
+ p_out[(x << 1)] = A;
+ pixel_type_w B = A - diff;
+ p_out[(x << 1) + 1] = B;
+ }
+ if (chout.w & 1) p_out[chout.w - 1] = p_avg[chin.w - 1];  // odd output width: the last sample has no residual, copy the average
+ };
+
+ // somewhat complicated trickery just to be able to SIMD this.
+ // Horizontal unsqueeze has horizontal data dependencies, so we do
+ // 8 rows at a time and treat it as a vertical unsqueeze of a
+ // transposed 8x8 block (or 9x8 for one input).
+ static constexpr const size_t kRowsPerThread = 8;
+ const auto unsqueeze_span = [&](const uint32_t task, size_t /* thread */) {  // one task = one group of up to 8 rows
+ const size_t y0 = task * kRowsPerThread;
+ const size_t rows = std::min(kRowsPerThread, chin.h - y0);
+ size_t x = 0;  // first input column not yet handled by the SIMD path
+
+#if HWY_TARGET != HWY_SCALAR
+ intptr_t onerow_in = chin.plane.PixelsPerRow();
+ intptr_t onerow_inr = chin_residual.plane.PixelsPerRow();
+ intptr_t onerow_out = chout.plane.PixelsPerRow();
+ const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y0);
+ const pixel_type *JXL_RESTRICT p_avg = chin.Row(y0);
+ pixel_type *JXL_RESTRICT p_out = chout.Row(y0);
+ HWY_ALIGN pixel_type b_p_avg[9 * kRowsPerThread];  // 8 transposed columns of averages plus one extra column used as next_avg
+ HWY_ALIGN pixel_type b_p_residual[8 * kRowsPerThread];
+ HWY_ALIGN pixel_type b_p_out_even[8 * kRowsPerThread];
+ HWY_ALIGN pixel_type b_p_out_odd[8 * kRowsPerThread];  // kept across iterations: the next block reads its last column as "previous output"
+ HWY_ALIGN pixel_type b_p_out_evenT[8 * kRowsPerThread];
+ HWY_ALIGN pixel_type b_p_out_oddT[8 * kRowsPerThread];
+ const HWY_CAPPED(pixel_type, 8) d;
+ const size_t N = Lanes(d);
+ if (chin_residual.w > 16 && rows == kRowsPerThread) {  // SIMD only for complete 8-row groups of sufficiently wide channels
+ for (; x < chin_residual.w - 9; x += 8) {  // stops early so that column x + 8 always exists; the rest goes to the scalar path
+ Transpose8x8Block(p_residual + x, b_p_residual, onerow_inr);
+ Transpose8x8Block(p_avg + x, b_p_avg, onerow_in);
+ for (size_t y = 0; y < kRowsPerThread; y++) {
+ b_p_avg[8 * 8 + y] = p_avg[x + 8 + onerow_in * y];  // ninth transposed column: averages at x + 8
+ }
+ for (size_t i = 0; i < 8; i++) {  // one input column at a time, all 8 rows in parallel
+ FastUnsqueeze(
+ b_p_residual + 8 * i, b_p_avg + 8 * i, b_p_avg + 8 * (i + 1),
+ (x + i ? b_p_out_odd + 8 * ((x + i - 1) & 7) : b_p_avg + 8 * i),  // previous odd output (index 7 of the previous block when i == 0), or the average for column 0
+ b_p_out_even + 8 * i, b_p_out_odd + 8 * i);
+ }
+
+ Transpose8x8Block(b_p_out_even, b_p_out_evenT, 8);
+ Transpose8x8Block(b_p_out_odd, b_p_out_oddT, 8);
+ for (size_t y = 0; y < kRowsPerThread; y++) {
+ for (size_t i = 0; i < kRowsPerThread; i += N) {
+ auto even = Load(d, b_p_out_evenT + 8 * y + i);
+ auto odd = Load(d, b_p_out_oddT + 8 * y + i);
+ StoreInterleaved(d, even, odd,
+ p_out + ((x + i) << 1) + onerow_out * y);  // even/odd samples alternate in the output row
+ }
+ }
+ }
+ }
+#endif
+ for (size_t y = 0; y < rows; y++) {  // scalar path finishes the remaining columns (all of them if the SIMD path was skipped)
+ unsqueeze_row(y0 + y, x);
+ }
+ };
+ JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.h, kRowsPerThread),
+ ThreadPool::NoInit, unsqueeze_span,
+ "InvHorizontalSqueeze"));
+ input.channel[c] = std::move(chout);  // the reconstructed channel replaces the averages; channel rc is left untouched here
+ return true;
+}
+
+Status InvVSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) {  // undoes a vertical squeeze: channel c (averages) + channel rc (residuals) -> taller channel c
+ JXL_ASSERT(c < input.channel.size());
+ JXL_ASSERT(rc < input.channel.size());
+ const Channel &chin = input.channel[c];
+ const Channel &chin_residual = input.channel[rc];
+ // These must be valid since we ran MetaApply already.
+ JXL_ASSERT(chin.h == DivCeil(chin.h + chin_residual.h, 2));
+ JXL_ASSERT(chin.w == chin_residual.w);
+
+ if (chin_residual.h == 0) {
+ // Short-circuit: output channel has same dimensions as input.
+ input.channel[c].vshift--;
+ return true;
+ }
+
+ // Note: chin.h >= chin_residual.h and at most 1 different.
+ Channel chout(chin.w, chin.h + chin_residual.h, chin.hshift, chin.vshift - 1);
+ JXL_DEBUG_V(
+ 4,
+ "Undoing vertical squeeze of channel %i using residuals in channel "
+ "%i (going from height %" PRIuS " to %" PRIuS ")",
+ c, rc, chin.h, chout.h);
+
+ if (chin_residual.w == 0) {
+ // Short-circuit: channel with no pixels.
+ input.channel[c] = std::move(chout);
+ return true;
+ }
+
+ static constexpr const int kColsPerThread = 64;
+ const auto unsqueeze_slice = [&](const uint32_t task, size_t /* thread */) {  // one task = one vertical slice of up to 64 columns
+ const size_t x0 = task * kColsPerThread;
+ const size_t x1 = std::min((size_t)(task + 1) * kColsPerThread, chin.w);
+ const size_t w = x1 - x0;
+ // We only iterate up to std::min(chin_residual.h, chin.h) which is
+ // always chin_residual.h.
+ for (size_t y = 0; y < chin_residual.h; y++) {  // rows are processed top to bottom because each one reads the previous output row
+ const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y) + x0;
+ const pixel_type *JXL_RESTRICT p_avg = chin.Row(y) + x0;
+ const pixel_type *JXL_RESTRICT p_navg =
+ chin.Row(y + 1 < chin.h ? y + 1 : y) + x0;  // no row below: reuse the current averages
+ pixel_type *JXL_RESTRICT p_out = chout.Row(y << 1) + x0;
+ pixel_type *JXL_RESTRICT p_nout = chout.Row((y << 1) + 1) + x0;
+ const pixel_type *p_pout = y > 0 ? chout.Row((y << 1) - 1) + x0 : p_avg;  // previous output row, or the averages for the first row
+ size_t x = 0;
+#if HWY_TARGET != HWY_SCALAR
+ for (; x + 7 < w; x += 8) {  // SIMD path, 8 columns at a time
+ FastUnsqueeze(p_residual + x, p_avg + x, p_navg + x, p_pout + x,
+ p_out + x, p_nout + x);
+ }
+#endif
+ for (; x < w; x++) {  // scalar path for the remaining columns
+ pixel_type_w avg = p_avg[x];
+ pixel_type_w next_avg = p_navg[x];
+ pixel_type_w top = p_pout[x];
+ pixel_type_w tendency = SmoothTendency(top, avg, next_avg);
+ pixel_type_w diff_minus_tendency = p_residual[x];
+ pixel_type_w diff = diff_minus_tendency + tendency;
+ pixel_type_w out = avg + (diff / 2);
+ p_out[x] = out;
+ // If the chin_residual.h == chin.h, the output has an even number
+ // of rows so the next line is fine. Otherwise, this loop won't
+ // write to the last output row which is handled separately.
+ p_nout[x] = out - diff;
+ }
+ }
+ };
+ JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.w, kColsPerThread),
+ ThreadPool::NoInit, unsqueeze_slice,
+ "InvVertSqueeze"));
+
+ if (chout.h & 1) {  // odd output height: the last row has no residual, copy the last row of averages
+ size_t y = chin.h - 1;
+ const pixel_type *p_avg = chin.Row(y);
+ pixel_type *p_out = chout.Row(y << 1);
+ for (size_t x = 0; x < chin.w; x++) {
+ p_out[x] = p_avg[x];
+ }
+ }
+ input.channel[c] = std::move(chout);  // the reconstructed channel replaces the averages; channel rc is left untouched here
+ return true;
+}
+
+// Undoes the squeeze steps in `parameters`, last step first. For every step
+// each channel of the range is recombined with its residual channel, and the
+// residual channels of the step are erased afterwards.
+Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters,
+                  ThreadPool *pool) {
+  for (auto it = parameters.rbegin(); it != parameters.rend(); ++it) {
+    const SqueezeParams &step = *it;
+    JXL_RETURN_IF_ERROR(CheckMetaSqueezeParams(step, input.channel.size()));
+    const uint32_t first = step.begin_c;
+    const uint32_t last = step.begin_c + step.num_c - 1;
+    // Index of the first residual channel: directly after the squeezed range
+    // when in place, otherwise the final (last - first + 1) channels.
+    const uint32_t residual_base =
+        step.in_place
+            ? last + 1
+            : static_cast<uint32_t>(input.channel.size() + first - last - 1);
+    if (first < input.nb_meta_channels) {
+      // This is checked in MetaSqueeze.
+      JXL_ASSERT(input.nb_meta_channels > step.num_c);
+      input.nb_meta_channels -= step.num_c;
+    }
+
+    for (uint32_t c = first; c <= last; c++) {
+      const uint32_t rc = residual_base + c - first;
+      // MetaApply should imply that `rc` is within range, otherwise there's a
+      // programming bug.
+      JXL_ASSERT(rc < input.channel.size());
+      const Channel &averages = input.channel[c];
+      const Channel &residuals = input.channel[rc];
+      // A residual channel can never be larger than its averages.
+      if (averages.w < residuals.w || averages.h < residuals.h) {
+        return JXL_FAILURE("Corrupted squeeze transform");
+      }
+      if (step.horizontal) {
+        JXL_RETURN_IF_ERROR(InvHSqueeze(input, c, rc, pool));
+      } else {
+        JXL_RETURN_IF_ERROR(InvVSqueeze(input, c, rc, pool));
+      }
+    }
+    const auto residuals_begin = input.channel.begin() + residual_base;
+    input.channel.erase(residuals_begin, residuals_begin + (last - first + 1));
+  }
+  return true;
+}
+
+} // namespace HWY_NAMESPACE
+} // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+
+namespace jxl {
+
+HWY_EXPORT(InvSqueeze);  // Registers the per-target InvSqueeze implementations for dispatch.
+Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters,
+ ThreadPool *pool) {  // Public entry point declared in squeeze.h; only forwards its arguments.
+ return HWY_DYNAMIC_DISPATCH(InvSqueeze)(input, parameters, pool);  // Presumably selects the implementation for the current CPU - see Highway docs.
+}
+
+void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters,
+ const Image &image) {  // Replaces *parameters with the default squeeze sequence derived from the image dimensions.
+ int nb_channels = image.channel.size() - image.nb_meta_channels;  // Number of non-meta channels.
+
+ parameters->clear();
+ size_t w = image.channel[image.nb_meta_channels].w;  // NOTE(review): indexes the first non-meta channel unchecked - confirm callers guarantee one exists.
+ size_t h = image.channel[image.nb_meta_channels].h;
+ JXL_DEBUG_V(
+ 7, "Default squeeze parameters for %" PRIuS "x%" PRIuS " image: ", w, h);
+
+ // Do horizontal first on wide images; vertical first on tall (or square) images.
+ bool wide = (w > h);
+
+ if (nb_channels > 2 && image.channel[image.nb_meta_channels + 1].w == w &&
+ image.channel[image.nb_meta_channels + 1].h == h) {  // Only the second channel's size is compared against the first.
+ // Assume channels 1 and 2 are chroma, and can be squeezed first for 4:2:0
+ // previews.
+ JXL_DEBUG_V(7, "(4:2:0 chroma), %" PRIuS "x%" PRIuS " image", w, h);
+ SqueezeParams params;
+ // Horizontal chroma squeeze.
+ params.horizontal = true;
+ params.in_place = false;  // MetaSqueeze appends these residuals at the end of the channel list.
+ params.begin_c = image.nb_meta_channels + 1;
+ params.num_c = 2;
+ parameters->push_back(params);
+ params.horizontal = false;
+ // Vertical chroma squeeze (same channels, same placement).
+ parameters->push_back(params);
+ }
+ SqueezeParams params;  // Template for the remaining steps: all non-meta channels, residuals in place.
+ params.begin_c = image.nb_meta_channels;
+ params.num_c = nb_channels;
+ params.in_place = true;
+
+ if (!wide) {  // Emit one vertical step up front so the loop below effectively starts with the vertical direction.
+ if (h > JXL_MAX_FIRST_PREVIEW_SIZE) {
+ params.horizontal = false;
+ parameters->push_back(params);
+ h = (h + 1) / 2;  // Size of the averaged channel after the squeeze (rounded up).
+ JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h);
+ }
+ }
+ while (w > JXL_MAX_FIRST_PREVIEW_SIZE || h > JXL_MAX_FIRST_PREVIEW_SIZE) {  // Alternate until both dimensions fit the preview size.
+ if (w > JXL_MAX_FIRST_PREVIEW_SIZE) {
+ params.horizontal = true;
+ parameters->push_back(params);
+ w = (w + 1) / 2;
+ JXL_DEBUG_V(7, "Horizontal (%" PRIuS "x%" PRIuS "), ", w, h);
+ }
+ if (h > JXL_MAX_FIRST_PREVIEW_SIZE) {
+ params.horizontal = false;
+ parameters->push_back(params);
+ h = (h + 1) / 2;
+ JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h);
+ }
+ }
+ JXL_DEBUG_V(7, "that's it");
+}
+
+Status CheckMetaSqueezeParams(const SqueezeParams &parameter,
+ int num_channels) {  // Fails unless [begin_c, begin_c + num_c - 1] is a non-empty range inside [0, num_channels).
+ int c1 = parameter.begin_c;  // First channel of the range.
+ int c2 = parameter.begin_c + parameter.num_c - 1;  // Last channel of the range (inclusive).
+ if (c1 < 0 || c1 >= num_channels || c2 < 0 || c2 >= num_channels || c2 < c1) {  // c2 < c1 also rejects num_c == 0.
+ return JXL_FAILURE("Invalid channel range");
+ }
+ return true;
+}
+
+Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters) {  // Updates channel sizes/shifts and inserts residual channels for each squeeze step; fills in default parameters if none given.
+ if (parameters->empty()) {
+ DefaultSqueezeParameters(parameters, image);  // Writes the defaults back into the caller's vector.
+ }
+
+ for (size_t i = 0; i < parameters->size(); i++) {
+ JXL_RETURN_IF_ERROR(
+ CheckMetaSqueezeParams((*parameters)[i], image.channel.size()));
+ bool horizontal = (*parameters)[i].horizontal;
+ bool in_place = (*parameters)[i].in_place;
+ uint32_t beginc = (*parameters)[i].begin_c;
+ uint32_t endc = (*parameters)[i].begin_c + (*parameters)[i].num_c - 1;  // Inclusive last squeezed channel.
+
+ uint32_t offset;  // Index at which the first residual channel is inserted.
+ if (beginc < image.nb_meta_channels) {  // Squeezing meta channels: the range must be all-meta and in place.
+ if (endc >= image.nb_meta_channels) {
+ return JXL_FAILURE("Invalid squeeze: mix of meta and nonmeta channels");
+ }
+ if (!in_place) {
+ return JXL_FAILURE(
+ "Invalid squeeze: meta channels require in-place residuals");
+ }
+ image.nb_meta_channels += (*parameters)[i].num_c;  // The inserted residuals count as meta channels too.
+ }
+ if (in_place) {
+ offset = endc + 1;  // Residuals go right after the squeezed channels.
+ } else {
+ offset = image.channel.size();  // Residuals are appended at the end.
+ }
+ for (uint32_t c = beginc; c <= endc; c++) {
+ if (image.channel[c].hshift > 30 || image.channel[c].vshift > 30) {
+ return JXL_FAILURE("Too many squeezes: shift > 30");
+ }
+ size_t w = image.channel[c].w;
+ size_t h = image.channel[c].h;
+ if (w == 0 || h == 0) return JXL_FAILURE("Squeezing empty channel");
+ if (horizontal) {
+ image.channel[c].w = (w + 1) / 2;  // Average channel keeps the rounded-up half.
+ if (image.channel[c].hshift >= 0) image.channel[c].hshift++;  // Negative shifts are left untouched.
+ w = w - (w + 1) / 2;  // Residual channel gets the remaining columns (rounded-down half).
+ } else {
+ image.channel[c].h = (h + 1) / 2;
+ if (image.channel[c].vshift >= 0) image.channel[c].vshift++;
+ h = h - (h + 1) / 2;  // Residual channel gets the remaining rows.
+ }
+ image.channel[c].shrink();
+ Channel dummy(w, h);  // Placeholder for the residual channel of channel c.
+ dummy.hshift = image.channel[c].hshift;
+ dummy.vshift = image.channel[c].vshift;
+
+ image.channel.insert(image.channel.begin() + offset + (c - beginc),
+ std::move(dummy));  // Keeps residuals in the same order as their source channels.
+ JXL_DEBUG_V(8, "MetaSqueeze applied, current image: %s",
+ image.DebugString().c_str());
+ }
+ }
+ return true;
+}
+
+} // namespace jxl
+
+#endif
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h
new file mode 100644
index 0000000000..fb18710a6f
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h
@@ -0,0 +1,90 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_
+#define LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_
+
+// Haar-like transform: halves the resolution in one direction
+// A B -> (A+B)>>1 in one channel (average) -> same range as
+// original channel
+// A-B - tendency in a new channel ('residual' needed to make
+// the transform reversible)
+// -> theoretically range could be 2.5
+// times larger (2 times without the
+// 'tendency'), but there should be lots
+// of zeroes
+// Repeated application (alternating horizontal and vertical squeezes) results
+// in downscaling
+//
+// The default coefficient ordering is low-frequency to high-frequency, as in
+// M. Antonini, M. Barlaud, P. Mathieu and I. Daubechies, "Image coding using
+// wavelet transform", IEEE Transactions on Image Processing, vol. 1, no. 2, pp.
+// 205-220, April 1992, doi: 10.1109/83.136597.
+
+#include <stdlib.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+#define JXL_MAX_FIRST_PREVIEW_SIZE 8
+
+namespace jxl {
+
+/*
+ int avg=(A+B)>>1;
+ int diff=(A-B);
+ int rA=(diff+(avg<<1)+(diff&1))>>1;
+ int rB=rA-diff;
+
+*/
+// |A B|C D|E F|
+// p a n p=avg(A,B), a=avg(C,D), n=avg(E,F)
+//
+// Goal: estimate C-D (avoiding ringing artifacts)
+// (ensuring that in smooth areas, a zero residual corresponds to a smooth
+// gradient)
+
+// best estimate for C: (B + 2*a)/3
+// best estimate for D: (n + 3*a)/4
+// best estimate for C-D: (4*B - 3*n - a)/12
+
+// avoid ringing by 1) only doing this if B <= a <= n or B >= a >= n
+// (otherwise, this is not a smooth area and we cannot really estimate C-D)
+// 2) making sure that B <= C <= D <= n or B >= C >= D >= n
+
+inline pixel_type_w SmoothTendency(pixel_type_w B, pixel_type_w a,
+ pixel_type_w n) {  // Estimates C-D from the previous output B, current average a and next average n; returns 0 unless B, a, n are monotonic.
+ pixel_type_w diff = 0;
+ if (B >= a && a >= n) {  // Non-increasing: the numerator below is >= 0.
+ diff = (4 * B - 3 * n - a + 6) / 12;  // +6 rounds the division by 12 to nearest.
+ // 2C = a<<1 + diff - diff&1 <= 2B so diff - diff&1 <= 2B - 2a
+ // 2D = a<<1 - diff - diff&1 >= 2n so diff + diff&1 <= 2a - 2n
+ if (diff - (diff & 1) > 2 * (B - a)) diff = 2 * (B - a) + 1;  // Clamp so that C <= B.
+ if (diff + (diff & 1) > 2 * (a - n)) diff = 2 * (a - n);  // Clamp so that D >= n.
+ } else if (B <= a && a <= n) {  // Non-decreasing: the numerator below is <= 0.
+ diff = (4 * B - 3 * n - a - 6) / 12;  // -6 rounds to nearest for a non-positive numerator.
+ // 2C = a<<1 + diff + diff&1 >= 2B so diff + diff&1 >= 2B - 2a
+ // 2D = a<<1 - diff + diff&1 <= 2n so diff - diff&1 >= 2a - 2n
+ if (diff + (diff & 1) < 2 * (B - a)) diff = 2 * (B - a) - 1;  // Clamp so that C >= B.
+ if (diff - (diff & 1) < 2 * (a - n)) diff = 2 * (a - n);  // Clamp so that D <= n.
+ }
+ return diff;
+}
+
+void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters,
+ const Image &image);  // Overwrites *parameters with the default squeeze steps for `image`.
+
+Status CheckMetaSqueezeParams(const SqueezeParams &parameter, int num_channels);  // Validates the channel range of one squeeze step.
+
+Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters);  // Applies the squeeze steps to the channel layout; fills in defaults if *parameters is empty.
+
+Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters,
+ ThreadPool *pool);  // Inverts the squeeze steps on `input`, last step first.
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc
new file mode 100644
index 0000000000..d9f2b435bf
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc
@@ -0,0 +1,98 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/modular/transform/transform.h"
+
+#include "lib/jxl/base/printf_macros.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/palette.h"
+#include "lib/jxl/modular/transform/rct.h"
+#include "lib/jxl/modular/transform/squeeze.h"
+
+namespace jxl {
+
+SqueezeParams::SqueezeParams() { Bundle::Init(this); }  // Presumably sets every field to its VisitFields default - confirm in fields.h.
+Transform::Transform(TransformId id) {  // Initializes the bundle, then selects the transform kind.
+ Bundle::Init(this);
+ this->id = id;  // Assigned after Init so the requested id is what ends up stored.
+}
+
+Status Transform::Inverse(Image &input, const weighted::Header &wp_header,
+ ThreadPool *pool) {  // Undoes this transform on `input` by dispatching on `id`; wp_header is only forwarded to InvPalette.
+ JXL_DEBUG_V(6, "Input channels (%" PRIuS ", %" PRIuS " meta): ",
+ input.channel.size(), input.nb_meta_channels);
+ switch (id) {
+ case TransformId::kRCT:
+ return InvRCT(input, begin_c, rct_type, pool);
+ case TransformId::kSqueeze:
+ return InvSqueeze(input, squeezes, pool);  // `squeezes` is passed by value (copied).
+ case TransformId::kPalette:
+ return InvPalette(input, begin_c, nb_colors, nb_deltas, predictor,
+ wp_header, pool);
+ default:  // Includes TransformId::kInvalid.
+ return JXL_FAILURE("Unknown transformation (ID=%u)",
+ static_cast<unsigned int>(id));
+ }
+}
+
+Status Transform::MetaApply(Image &input) {  // Validates this transform against `input` and, for squeeze/palette, runs the corresponding Meta* step.
+ JXL_DEBUG_V(6, "MetaApply input: %s", input.DebugString().c_str());
+ switch (id) {
+ case TransformId::kRCT:
+ JXL_DEBUG_V(2, "Transform: kRCT, rct_type=%" PRIu32, rct_type);
+ return CheckEqualChannels(input, begin_c, begin_c + 2);  // RCT only checks that channels begin_c..begin_c+2 match.
+ case TransformId::kSqueeze:
+ JXL_DEBUG_V(2, "Transform: kSqueeze:");
+#if JXL_DEBUG_V_LEVEL >= 2
+ {
+ auto squeezes_copy = squeezes;  // Work on a copy so logging never alters `squeezes`.
+ if (squeezes_copy.empty()) {
+ DefaultSqueezeParameters(&squeezes_copy, input);
+ }
+ for (const auto &params : squeezes_copy) {
+ JXL_DEBUG_V(
+ 2,
+ " squeeze params: horizontal=%d, in_place=%d, begin_c=%" PRIu32
+ ", num_c=%" PRIu32,
+ params.horizontal, params.in_place, params.begin_c, params.num_c);
+ }
+ }
+#endif
+ return MetaSqueeze(input, &squeezes);  // MetaSqueeze fills `squeezes` with the defaults when it is empty.
+ case TransformId::kPalette:
+ JXL_DEBUG_V(2,
+ "Transform: kPalette, begin_c=%" PRIu32 ", num_c=%" PRIu32
+ ", nb_colors=%" PRIu32 ", nb_deltas=%" PRIu32,
+ begin_c, num_c, nb_colors, nb_deltas);
+ return MetaPalette(input, begin_c, begin_c + num_c - 1, nb_colors,
+ nb_deltas, lossy_palette);  // Second channel argument is the inclusive end of the range.
+ default:  // Includes TransformId::kInvalid.
+ return JXL_FAILURE("Unknown transformation (ID=%u)",
+ static_cast<unsigned int>(id));
+ }
+}
+
+Status CheckEqualChannels(const Image &image, uint32_t c1, uint32_t c2) {  // Succeeds iff channels c1..c2 (inclusive) exist, are all meta or all non-meta, and share w/h/hshift/vshift.
+ if (c1 >= image.channel.size() || c2 >= image.channel.size() || c2 < c1) {  // Both ends must be valid indices and ordered.
+ return JXL_FAILURE("Invalid channel range: %" PRIu32 "..%" PRIu32
+ " (there are only %" PRIuS " channels)",
+ c1, c2, image.channel.size());
+ }
+ if (c1 < image.nb_meta_channels && c2 >= image.nb_meta_channels) {  // Range straddles the meta/non-meta boundary.
+ return JXL_FAILURE("Invalid: transforming mix of meta and nonmeta");
+ }
+ const auto &ch1 = image.channel[c1];  // Reference channel every other channel is compared against.
+ for (size_t c = c1 + 1; c <= c2; c++) {
+ const auto &ch2 = image.channel[c];
+ if (ch1.w != ch2.w || ch1.h != ch2.h || ch1.hshift != ch2.hshift ||
+ ch1.vshift != ch2.vshift) {
+ return false;  // Mismatch is reported as a plain failure status, without a JXL_FAILURE message.
+ }
+ }
+ return true;
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h
new file mode 100644
index 0000000000..d5d3259f7a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h
@@ -0,0 +1,148 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_
+#define LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/options.h"
+
+namespace jxl {
+
+enum class TransformId : uint32_t {  // Numeric values are what Transform::VisitFields reads and writes for `id`.
+ // G, R-G, B-G and variants (including YCoCg).
+ kRCT = 0,
+
+ // Color palette. Serialized parameters: [begin_c] [num_c] [nb_colors] [nb_deltas] [predictor]
+ kPalette = 1,
+
+ // Squeezing (Haar-style)
+ kSqueeze = 2,
+
+ // Invalid for now; Transform::VisitFields rejects it.
+ kInvalid = 3,
+};
+
+struct SqueezeParams : public Fields {  // One squeeze step: direction, residual placement and channel range.
+ JXL_FIELDS_NAME(SqueezeParams)
+ bool horizontal;  // true: squeeze columns (halves w); false: squeeze rows (halves h).
+ bool in_place;  // true: residuals are inserted right after the squeezed channels; false: appended at the end.
+ uint32_t begin_c;  // First channel to squeeze.
+ uint32_t num_c;  // Number of consecutive channels to squeeze.
+ SqueezeParams();
+ Status VisitFields(Visitor *JXL_RESTRICT visitor) override {  // Visits the fields in serialization order.
+ JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &horizontal));  // Default: false.
+ JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &in_place));  // Default: false.
+ JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(3), BitsOffset(6, 8),
+ BitsOffset(10, 72),
+ BitsOffset(13, 1096), 0, &begin_c));  // Same distribution as Transform::begin_c; default 0.
+ JXL_QUIET_RETURN_IF_ERROR(
+ visitor->U32(Val(1), Val(2), Val(3), BitsOffset(4, 4), 2, &num_c));  // Default 2.
+ return true;
+ }
+};
+
+class Transform : public Fields {  // One modular-image transform (RCT, palette or squeeze) together with its parameters.
+ public:
+ TransformId id;  // Selects which of the parameter groups below is used.
+ // First channel the transform applies to; for Palette and RCT.
+ uint32_t begin_c;
+ // for RCT. 42 possible values starting from 0.
+ uint32_t rct_type;
+ // Number of channels; only for Palette and NearLossless.
+ uint32_t num_c;
+ // Only for Palette.
+ uint32_t nb_colors;
+ uint32_t nb_deltas;
+ // for Squeeze. Default squeeze if empty.
+ std::vector<SqueezeParams> squeezes;
+ // for NearLossless, not serialized.
+ int max_delta_error;
+ // Serialized for Palette.
+ Predictor predictor;
+ // for Palette, not serialized.
+ bool ordered_palette = true;
+ bool lossy_palette = false;
+
+ explicit Transform(TransformId id);
+ // Default constructor for bundles; starts out as kInvalid.
+ Transform() : Transform(TransformId::kInvalid) {}
+
+ Status VisitFields(Visitor *JXL_RESTRICT visitor) override {  // Visits `id`, then only the fields relevant to that id.
+ JXL_QUIET_RETURN_IF_ERROR(visitor->U32(
+ Val((uint32_t)TransformId::kRCT), Val((uint32_t)TransformId::kPalette),
+ Val((uint32_t)TransformId::kSqueeze),
+ Val((uint32_t)TransformId::kInvalid), (uint32_t)TransformId::kRCT,
+ reinterpret_cast<uint32_t *>(&id)));  // The enum is visited through its uint32_t underlying type; default kRCT.
+ if (id == TransformId::kInvalid) {
+ return JXL_FAILURE("Invalid transform ID");
+ }
+ if (visitor->Conditional(id == TransformId::kRCT ||
+ id == TransformId::kPalette)) {  // begin_c is shared by RCT and Palette.
+ JXL_QUIET_RETURN_IF_ERROR(
+ visitor->U32(Bits(3), BitsOffset(6, 8), BitsOffset(10, 72),
+ BitsOffset(13, 1096), 0, &begin_c));
+ }
+ if (visitor->Conditional(id == TransformId::kRCT)) {
+ // 0-41, default YCoCg.
+ JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(6), Bits(2), BitsOffset(4, 2),
+ BitsOffset(6, 10), 6, &rct_type));
+ if (rct_type >= 42) {  // Values outside the documented 0-41 range are rejected.
+ return JXL_FAILURE("Invalid transform RCT type");
+ }
+ }
+ if (visitor->Conditional(id == TransformId::kPalette)) {
+ JXL_QUIET_RETURN_IF_ERROR(
+ visitor->U32(Val(1), Val(3), Val(4), BitsOffset(13, 1), 3, &num_c));  // Default 3.
+ JXL_QUIET_RETURN_IF_ERROR(visitor->U32(
+ BitsOffset(8, 0), BitsOffset(10, 256), BitsOffset(12, 1280),
+ BitsOffset(16, 5376), 256, &nb_colors));  // Default 256.
+ JXL_QUIET_RETURN_IF_ERROR(
+ visitor->U32(Val(0), BitsOffset(8, 1), BitsOffset(10, 257),
+ BitsOffset(16, 1281), 0, &nb_deltas));  // Default 0.
+ JXL_QUIET_RETURN_IF_ERROR(
+ visitor->Bits(4, (uint32_t)Predictor::Zero,
+ reinterpret_cast<uint32_t *>(&predictor)));  // 4-bit field; default Predictor::Zero.
+ if (predictor >= Predictor::Best) {  // Predictor::Best and anything above it are rejected here.
+ return JXL_FAILURE("Invalid predictor");
+ }
+ }
+
+ if (visitor->Conditional(id == TransformId::kSqueeze)) {
+ uint32_t num_squeezes = static_cast<uint32_t>(squeezes.size());  // Current count, so a non-reading visitor sees the real size.
+ JXL_QUIET_RETURN_IF_ERROR(
+ visitor->U32(Val(0), BitsOffset(4, 1), BitsOffset(6, 9),
+ BitsOffset(8, 41), 0, &num_squeezes));  // Default 0, i.e. default squeeze.
+ if (visitor->IsReading()) squeezes.resize(num_squeezes);  // Make room for the entries visited below.
+ for (size_t i = 0; i < num_squeezes; i++) {
+ JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&squeezes[i]));
+ }
+ }
+ return true;
+ }
+
+ JXL_FIELDS_NAME(Transform)
+
+ Status Inverse(Image &input, const weighted::Header &wp_header,
+ ThreadPool *pool = nullptr);  // Undoes the transform on `input`.
+ Status MetaApply(Image &input);  // Validates the transform and applies its channel-layout changes to `input`.
+};
+
+Status CheckEqualChannels(const Image &image, uint32_t c1, uint32_t c2);
+
+static inline pixel_type PixelAdd(pixel_type a, pixel_type b) {  // Adds a and b in uint32_t, so the sum wraps modulo 2^32.
+ return static_cast<pixel_type>(static_cast<uint32_t>(a) +
+ static_cast<uint32_t>(b));  // Presumably pixel_type is a signed 32-bit type and this avoids signed-overflow UB - confirm.
+}
+
+} // namespace jxl
+
+#endif // LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_