diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 01:47:29 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 01:47:29 +0000 |
commit | 0ebf5bdf043a27fd3dfb7f92e0cb63d88954c44d (patch) | |
tree | a31f07c9bcca9d56ce61e9a1ffd30ef350d513aa /third_party/jpeg-xl/lib/jxl/modular/encoding | |
parent | Initial commit. (diff) | |
download | firefox-esr-0ebf5bdf043a27fd3dfb7f92e0cb63d88954c44d.tar.xz firefox-esr-0ebf5bdf043a27fd3dfb7f92e0cb63d88954c44d.zip |
Adding upstream version 115.8.0esr.upstream/115.8.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/modular/encoding')
12 files changed, 3524 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h new file mode 100644 index 0000000000..914cd6a4e4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h @@ -0,0 +1,626 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ +#define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ + +#include <utility> +#include <vector> + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +namespace weighted { +constexpr static size_t kNumPredictors = 4; +constexpr static int64_t kPredExtraBits = 3; +constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1; +constexpr static size_t kNumProperties = 1; + +struct Header : public Fields { + JXL_FIELDS_NAME(WeightedPredictorHeader) + // TODO(janwas): move to cc file, avoid including fields.h. + Header() { Bundle::Init(this); } + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + auto visit_p = [visitor](pixel_type val, pixel_type *p) { + uint32_t up = *p; + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up)); + *p = up; + return Status(true); + }; + JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3])); + return true; + } + + bool all_default; + pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0; + uint32_t w[kNumPredictors] = {}; +}; + +struct State { + pixel_type_w prediction[kNumPredictors] = {}; + pixel_type_w pred = 0; // *before* removing the added bits. + std::vector<uint32_t> pred_errors[kNumPredictors]; + std::vector<int32_t> error; + const Header header; + + // Allows to approximate division by a number from 1 to 64. + uint32_t divlookup[64]; + + constexpr static pixel_type_w AddBits(pixel_type_w x) { + return uint64_t(x) << kPredExtraBits; + } + + State(Header header, size_t xsize, size_t ysize) : header(header) { + // Extra margin to avoid out-of-bounds writes. + // All have space for two rows of data. + for (size_t i = 0; i < 4; i++) { + pred_errors[i].resize((xsize + 2) * 2); + } + error.resize((xsize + 2) * 2); + // Initialize division lookup table. + for (int i = 0; i < 64; i++) { + divlookup[i] = (1 << 24) / (i + 1); + } + } + + // Approximates 4+(maxweight<<24)/(x+1), avoiding division + JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const { + int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5; + if (shift < 0) shift = 0; + return 4 + ((maxweight * divlookup[x >> shift]) >> shift); + } + + // Approximates the weighted average of the input values with the given + // weights, avoiding division. Weights must sum to at least 16. + JXL_INLINE pixel_type_w + WeightedAverage(const pixel_type_w *JXL_RESTRICT p, + std::array<uint32_t, kNumPredictors> w) const { + uint32_t weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + weight_sum += w[i]; + } + JXL_DASSERT(weight_sum > 15); + uint32_t log_weight = FloorLog2Nonzero(weight_sum); // at least 4. + weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + w[i] >>= log_weight - 4; + weight_sum += w[i]; + } + // for rounding. + pixel_type_w sum = (weight_sum >> 1) - 1; + for (size_t i = 0; i < kNumPredictors; i++) { + sum += p[i] * w[i]; + } + return (sum * divlookup[weight_sum - 1]) >> 24; + } + + template <bool compute_properties> + JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize, + pixel_type_w N, pixel_type_w W, + pixel_type_w NE, pixel_type_w NW, + pixel_type_w NN, Properties *properties, + size_t offset) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + size_t pos_N = prev_row + x; + size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N; + size_t pos_NW = x > 0 ? pos_N - 1 : pos_N; + std::array<uint32_t, kNumPredictors> weights; + for (size_t i = 0; i < kNumPredictors; i++) { + // pred_errors[pos_N] also contains the error of pixel W. + // pred_errors[pos_NW] also contains the error of pixel WW. + weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] + + pred_errors[i][pos_NW]; + weights[i] = ErrorWeight(weights[i], header.w[i]); + } + + N = AddBits(N); + W = AddBits(W); + NE = AddBits(NE); + NW = AddBits(NW); + NN = AddBits(NN); + + pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1]; + pixel_type_w teN = error[pos_N]; + pixel_type_w teNW = error[pos_NW]; + pixel_type_w sumWN = teN + teW; + pixel_type_w teNE = error[pos_NE]; + + if (compute_properties) { + pixel_type_w p = teW; + if (std::abs(teN) > std::abs(p)) p = teN; + if (std::abs(teNW) > std::abs(p)) p = teNW; + if (std::abs(teNE) > std::abs(p)) p = teNE; + (*properties)[offset++] = p; + } + + prediction[0] = W + NE - N; + prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5); + prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5); + prediction[3] = + N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc + + (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >> + 5); + + pred = WeightedAverage(prediction, weights); + + // If all three have the same sign, skip clamping. + if (((teN ^ teW) | (teN ^ teNW)) > 0) { + return (pred + kPredictionRound) >> kPredExtraBits; + } + + // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N). + pixel_type_w mx = std::max(W, std::max(NE, N)); + pixel_type_w mn = std::min(W, std::min(NE, N)); + pred = std::max(mn, std::min(mx, pred)); + return (pred + kPredictionRound) >> kPredExtraBits; + } + + JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y, + size_t xsize) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + val = AddBits(val); + error[cur_row + x] = pred - val; + for (size_t i = 0; i < kNumPredictors; i++) { + pixel_type_w err = + (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits; + // For predicting in the next row. + pred_errors[i][cur_row + x] = err; + // Add the error on this pixel to the error on the NE pixel. This has the + // effect of adding the error on this pixel to the E and EE pixels. + pred_errors[i][prev_row + x + 1] += err; + } + } +}; + +// Encoder helper function to set the parameters to some presets. +inline void PredictorMode(int i, Header *header) { + switch (i) { + case 0: + // ~ lossless16 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 10; + header->p3Ca = 7; + header->p3Cb = 7; + header->p3Cc = 7; + header->p3Cd = 0; + header->p3Ce = 0; + break; + case 1: + // ~ default lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xb; + header->p1C = 8; + header->p2C = 8; + header->p3Ca = 4; + header->p3Cb = 0; + header->p3Cc = 3; + header->p3Cd = 23; + header->p3Ce = 2; + break; + case 2: + // ~ west lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xd; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 9; + header->p3Ca = 7; + header->p3Cb = 0; + header->p3Cc = 0; + header->p3Cd = 16; + header->p3Ce = 9; + break; + case 3: + // ~ north lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xd; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 8; + header->p3Ca = 0; + header->p3Cb = 16; + header->p3Cc = 0; + header->p3Cd = 23; + header->p3Ce = 0; + break; + case 4: + default: + // something else, because why not + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 10; + header->p3Ca = 5; + header->p3Cb = 5; + header->p3Cc = 5; + header->p3Cd = 12; + header->p3Ce = 4; + break; + } +} +} // namespace weighted + +// Stores a node and its two children at the same time. This significantly +// reduces the number of branches needed during decoding. +struct FlatDecisionNode { + // Property + splitval of the top node. + int32_t property0; // -1 if leaf. + union { + PropertyVal splitval0; + Predictor predictor; + }; + uint32_t childID; // childID is ctx id if leaf. + // Property+splitval of the two child nodes. + union { + PropertyVal splitvals[2]; + int32_t multiplier; + }; + union { + int32_t properties[2]; + int64_t predictor_offset; + }; +}; +using FlatTree = std::vector<FlatDecisionNode>; + +class MATreeLookup { + public: + explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {} + struct LookupResult { + uint32_t context; + Predictor predictor; + int64_t offset; + int32_t multiplier; + }; + JXL_INLINE LookupResult Lookup(const Properties &properties) const { + uint32_t pos = 0; + while (true) { + const FlatDecisionNode &node = nodes_[pos]; + if (node.property0 < 0) { + return {node.childID, node.predictor, node.predictor_offset, + node.multiplier}; + } + bool p0 = properties[node.property0] <= node.splitval0; + uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0]; + uint32_t off1 = + 2 | (properties[node.properties[1]] <= node.splitvals[1] ? 1 : 0); + pos = node.childID + (p0 ? off1 : off0); + } + } + + private: + const FlatTree &nodes_; +}; + +static constexpr size_t kExtraPropsPerChannel = 4; +static constexpr size_t kNumNonrefProperties = + kNumStaticProperties + 13 + weighted::kNumProperties; + +constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties; +constexpr size_t kGradientProp = 9; + +// Clamps gradient to the min/max of n, w (and l, implicitly). +static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w, + const int32_t l) { + const int32_t m = std::min(n, w); + const int32_t M = std::max(n, w); + // The end result of this operation doesn't overflow or underflow if the + // result is between m and M, but the intermediate value may overflow, so we + // do the intermediate operations in uint32_t and check later if we had an + // overflow or underflow condition comparing m, M and l directly. + // grad = M + m - l = n + w - l + const int32_t grad = + static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) - + static_cast<uint32_t>(l)); + // We use two sets of ternary operators to force the evaluation of them in + // any case, allowing the compiler to avoid branches and use cmovl/cmovg in + // x86. + const int32_t grad_clamp_M = (l < m) ? M : grad; + return (l > M) ? m : grad_clamp_M; +} + +inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) { + pixel_type_w p = a + b - c; + pixel_type_w pa = std::abs(p - a); + pixel_type_w pb = std::abs(p - b); + return pa < pb ? a : b; +} + +inline void PrecomputeReferences(const Channel &ch, size_t y, + const Image &image, uint32_t i, + Channel *references) { + ZeroFillImage(&references->plane); + uint32_t offset = 0; + size_t num_extra_props = references->w; + intptr_t onerow = references->plane.PixelsPerRow(); + for (int32_t j = static_cast<int32_t>(i) - 1; + j >= 0 && offset < num_extra_props; j--) { + if (image.channel[j].w != image.channel[i].w || + image.channel[j].h != image.channel[i].h) { + continue; + } + if (image.channel[j].hshift != image.channel[i].hshift) continue; + if (image.channel[j].vshift != image.channel[i].vshift) continue; + pixel_type *JXL_RESTRICT rp = references->Row(0) + offset; + const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y); + const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0); + for (size_t x = 0; x < ch.w; x++, rp += onerow) { + pixel_type_w v = rpp[x]; + rp[0] = std::abs(v); + rp[1] = v; + pixel_type_w vleft = (x ? rpp[x - 1] : 0); + pixel_type_w vtop = (y ? rpprev[x] : vleft); + pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft); + pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft); + rp[2] = std::abs(v - vpredicted); + rp[3] = v - vpredicted; + } + + offset += kExtraPropsPerChannel; + } +} + +struct PredictionResult { + int context = 0; + pixel_type_w guess = 0; + Predictor predictor; + int32_t multiplier; +}; + +inline void InitPropsRow( + Properties *p, + const std::array<pixel_type, kNumStaticProperties> &static_props, + const int y) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + (*p)[i] = static_props[i]; + } + (*p)[2] = y; + (*p)[9] = 0; // local gradient. +} + +namespace detail { +enum PredictorMode { + kUseTree = 1, + kUseWP = 2, + kForceComputeProperties = 4, + kAllPredictions = 8, + kNoEdgeCases = 16 +}; + +JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left, + pixel_type_w top, pixel_type_w toptop, + pixel_type_w topleft, pixel_type_w topright, + pixel_type_w leftleft, + pixel_type_w toprightright, + pixel_type_w wp_pred) { + switch (p) { + case Predictor::Zero: + return pixel_type_w{0}; + case Predictor::Left: + return left; + case Predictor::Top: + return top; + case Predictor::Select: + return Select(left, top, topleft); + case Predictor::Weighted: + return wp_pred; + case Predictor::Gradient: + return pixel_type_w{ClampedGradient(left, top, topleft)}; + case Predictor::TopLeft: + return topleft; + case Predictor::TopRight: + return topright; + case Predictor::LeftLeft: + return leftleft; + case Predictor::Average0: + return (left + top) / 2; + case Predictor::Average1: + return (left + topleft) / 2; + case Predictor::Average2: + return (topleft + top) / 2; + case Predictor::Average3: + return (top + topright) / 2; + case Predictor::Average4: + return (6 * top - 2 * toptop + 7 * left + 1 * leftleft + + 1 * toprightright + 3 * topright + 8) / + 16; + default: + return pixel_type_w{0}; + } +} + +template <int mode> +JXL_INLINE PredictionResult Predict( + Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const size_t x, const size_t y, Predictor predictor, + const MATreeLookup *lookup, const Channel *references, + weighted::State *wp_state, pixel_type_w *predictions) { + // We start in position 3 because of 2 static properties + y. + size_t offset = 3; + constexpr bool compute_properties = + mode & kUseTree || mode & kForceComputeProperties; + constexpr bool nec = mode & kNoEdgeCases; + pixel_type_w left = (nec || x ? pp[-1] : (y ? pp[-onerow] : 0)); + pixel_type_w top = (nec || y ? pp[-onerow] : left); + pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left); + pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top); + pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left); + pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top); + pixel_type_w toprightright = + (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright); + + if (compute_properties) { + // location + (*p)[offset++] = x; + // neighbors + (*p)[offset++] = std::abs(top); + (*p)[offset++] = std::abs(left); + (*p)[offset++] = top; + (*p)[offset++] = left; + + // local gradient + (*p)[offset] = left - (*p)[offset + 1]; + offset++; + // local gradient + (*p)[offset++] = left + top - topleft; + + // FFV1 context properties + (*p)[offset++] = left - topleft; + (*p)[offset++] = topleft - top; + (*p)[offset++] = top - topright; + (*p)[offset++] = top - toptop; + (*p)[offset++] = left - leftleft; + } + + pixel_type_w wp_pred = 0; + if (mode & kUseWP) { + wp_pred = wp_state->Predict<compute_properties>( + x, y, w, top, left, topright, topleft, toptop, p, offset); + } + if (!nec && compute_properties) { + offset += weighted::kNumProperties; + // Extra properties. + const pixel_type *JXL_RESTRICT rp = references->Row(x); + for (size_t i = 0; i < references->w; i++) { + (*p)[offset++] = rp[i]; + } + } + PredictionResult result; + if (mode & kUseTree) { + MATreeLookup::LookupResult lr = lookup->Lookup(*p); + result.context = lr.context; + result.guess = lr.offset; + result.multiplier = lr.multiplier; + predictor = lr.predictor; + } + if (mode & kAllPredictions) { + for (size_t i = 0; i < kNumModularPredictors; i++) { + predictions[i] = PredictOne((Predictor)i, left, top, toptop, topleft, + topright, leftleft, toprightright, wp_pred); + } + } + result.guess += PredictOne(predictor, left, top, toptop, topleft, topright, + leftleft, toprightright, wp_pred); + result.predictor = predictor; + + return result; +} +} // namespace detail + +inline PredictionResult PredictNoTreeNoWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor) { + return detail::Predict</*mode=*/0>( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictNoTreeWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + weighted::State *wp_state) { + return detail::Predict<detail::kUseWP>( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeNoWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references) { + return detail::Predict<detail::kUseTree>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + /*wp_state=*/nullptr, /*predictions=*/nullptr); +} +// Only use for y > 1, x > 1, x < w-2, and empty references +JXL_INLINE PredictionResult +PredictTreeNoWPNEC(Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const MATreeLookup &tree_lookup, const Channel &references) { + return detail::Predict<detail::kUseTree | detail::kNoEdgeCases>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict<detail::kUseTree | detail::kUseWP>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictLearn(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>( + p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references, + wp_state, /*predictions=*/nullptr); +} + +inline void PredictLearnAll(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const Channel &references, + weighted::State *wp_state, + pixel_type_w *predictions) { + detail::Predict<detail::kForceComputeProperties | detail::kUseWP | + detail::kAllPredictions>( + p, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, &references, wp_state, predictions); +} + +inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + pixel_type_w *predictions) { + detail::Predict<detail::kAllPredictions>( + /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, predictions); +} +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc new file mode 100644 index 0000000000..66562f7dfd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc @@ -0,0 +1,107 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/dec_ma.h" + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +namespace { + +Status ValidateTree( + const Tree &tree, + const std::vector<std::pair<pixel_type, pixel_type>> &prop_bounds, + size_t root) { + if (tree[root].property == -1) return true; + size_t p = tree[root].property; + int val = tree[root].splitval; + if (prop_bounds[p].first > val) return JXL_FAILURE("Invalid tree"); + // Splitting at max value makes no sense: left range will be exactly same + // as parent, right range will be invalid (min > max). + if (prop_bounds[p].second <= val) return JXL_FAILURE("Invalid tree"); + auto new_bounds = prop_bounds; + new_bounds[p].first = val + 1; + JXL_RETURN_IF_ERROR(ValidateTree(tree, new_bounds, tree[root].lchild)); + new_bounds[p] = prop_bounds[p]; + new_bounds[p].second = val; + return ValidateTree(tree, new_bounds, tree[root].rchild); +} + +Status DecodeTree(BitReader *br, ANSSymbolReader *reader, + const std::vector<uint8_t> &context_map, Tree *tree, + size_t tree_size_limit) { + size_t leaf_id = 0; + size_t to_decode = 1; + tree->clear(); + while (to_decode > 0) { + JXL_RETURN_IF_ERROR(br->AllReadsWithinBounds()); + if (tree->size() > tree_size_limit) { + return JXL_FAILURE("Tree is too large: %" PRIuS " nodes vs %" PRIuS + " max nodes", + tree->size(), tree_size_limit); + } + to_decode--; + uint32_t prop1 = reader->ReadHybridUint(kPropertyContext, br, context_map); + if (prop1 > 256) return JXL_FAILURE("Invalid tree property value"); + int property = prop1 - 1; + if (property == -1) { + size_t predictor = + reader->ReadHybridUint(kPredictorContext, br, context_map); + if (predictor >= kNumModularPredictors) { + return JXL_FAILURE("Invalid predictor"); + } + int64_t predictor_offset = + UnpackSigned(reader->ReadHybridUint(kOffsetContext, br, context_map)); + uint32_t mul_log = + reader->ReadHybridUint(kMultiplierLogContext, br, context_map); + if (mul_log >= 31) { + return JXL_FAILURE("Invalid multiplier logarithm"); + } + uint32_t mul_bits = + reader->ReadHybridUint(kMultiplierBitsContext, br, context_map); + if (mul_bits + 1 >= 1u << (31u - mul_log)) { + return JXL_FAILURE("Invalid multiplier"); + } + uint32_t multiplier = (mul_bits + 1U) << mul_log; + tree->emplace_back(-1, 0, leaf_id++, 0, static_cast<Predictor>(predictor), + predictor_offset, multiplier); + continue; + } + int splitval = + UnpackSigned(reader->ReadHybridUint(kSplitValContext, br, context_map)); + tree->emplace_back(property, splitval, tree->size() + to_decode + 1, + tree->size() + to_decode + 2, Predictor::Zero, 0, 1); + to_decode += 2; + } + std::vector<std::pair<pixel_type, pixel_type>> prop_bounds; + prop_bounds.resize(256, {std::numeric_limits<pixel_type>::min(), + std::numeric_limits<pixel_type>::max()}); + return ValidateTree(*tree, prop_bounds, 0); +} +} // namespace + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit) { + std::vector<uint8_t> tree_context_map; + ANSCode tree_code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kNumTreeContexts, &tree_code, &tree_context_map)); + // TODO(eustas): investigate more infinite tree cases. + if (tree_code.degenerate_symbols[tree_context_map[kPropertyContext]] > 0) { + return JXL_FAILURE("Infinite tree"); + } + ANSSymbolReader reader(&tree_code, br); + JXL_RETURN_IF_ERROR(DecodeTree(br, &reader, tree_context_map, tree, + std::min(tree_size_limit, kMaxTreeSize))); + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("ANS decode final state failed"); + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h new file mode 100644 index 0000000000..a910c4deb1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ +#define LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// inner nodes +struct PropertyDecisionNode { + PropertyVal splitval; + int16_t property; // -1: leaf node, lchild points to leaf node + uint32_t lchild; + uint32_t rchild; + Predictor predictor; + int64_t predictor_offset; + uint32_t multiplier; + + PropertyDecisionNode(int p, int split_val, int lchild, int rchild, + Predictor predictor, int64_t predictor_offset, + uint32_t multiplier) + : splitval(split_val), + property(p), + lchild(lchild), + rchild(rchild), + predictor(predictor), + predictor_offset(predictor_offset), + multiplier(multiplier) {} + PropertyDecisionNode() + : splitval(0), + property(-1), + lchild(0), + rchild(0), + predictor(Predictor::Zero), + predictor_offset(0), + multiplier(1) {} + static PropertyDecisionNode Leaf(Predictor predictor, int64_t offset = 0, + uint32_t multiplier = 1) { + return PropertyDecisionNode(-1, 0, 0, 0, predictor, offset, multiplier); + } + static PropertyDecisionNode Split(int p, int split_val, int lchild, + int rchild = -1) { + if (rchild == -1) rchild = lchild + 1; + return PropertyDecisionNode(p, split_val, lchild, rchild, Predictor::Zero, + 0, 1); + } +}; + +using Tree = std::vector<PropertyDecisionNode>; + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc new file mode 100644 index 0000000000..f2a1705e4b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc @@ -0,0 +1,124 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/enc_debug_tree.h" + +#include <stdint.h> +#include <stdlib.h> + +#include "lib/jxl/base/os_macros.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/options.h" + +#if JXL_OS_IOS +#define JXL_ENABLE_DOT 0 +#else +#define JXL_ENABLE_DOT 1 // iOS lacks C89 system() +#endif + +namespace jxl { + +const char *PredictorName(Predictor p) { + switch (p) { + case Predictor::Zero: + return "Zero"; + case Predictor::Left: + return "Left"; + case Predictor::Top: + return "Top"; + case Predictor::Average0: + return "Avg0"; + case Predictor::Average1: + return "Avg1"; + case Predictor::Average2: + return "Avg2"; + case Predictor::Average3: + return "Avg3"; + case Predictor::Average4: + return "Avg4"; + case Predictor::Select: + return "Sel"; + case Predictor::Gradient: + return "Grd"; + case Predictor::Weighted: + return "Wgh"; + case Predictor::TopLeft: + return "TopL"; + case Predictor::TopRight: + return "TopR"; + case Predictor::LeftLeft: + return "LL"; + default: + return "INVALID"; + }; +} + +std::string PropertyName(size_t i) { + static_assert(kNumNonrefProperties == 16, "Update this function"); + switch (i) { + case 0: + return "c"; + case 1: + return "g"; + case 2: + return "y"; + case 3: + return "x"; + case 4: + return "|N|"; + case 5: + return "|W|"; + case 6: + return "N"; + case 7: + return "W"; + case 8: + return "W-WW-NW+NWW"; + case 9: + return "W+N-NW"; + case 10: + return "W-NW"; + case 11: + return "NW-N"; + case 12: + return "N-NE"; + case 13: + return "N-NN"; + case 14: + return "W-WW"; + case 15: + return "WGH"; + default: + return "ch[" + ToString(15 - (int)i) + "]"; + } +} + +void PrintTree(const Tree &tree, const std::string &path) { + FILE *f = fopen((path + ".dot").c_str(), "w"); + fprintf(f, "graph{\n"); + for (size_t cur = 0; cur < tree.size(); cur++) { + if (tree[cur].property < 0) { + fprintf(f, "n%05" PRIuS " [label=\"%s%+" PRId64 " (x%u)\"];\n", cur, + PredictorName(tree[cur].predictor), tree[cur].predictor_offset, + tree[cur].multiplier); + } else { + fprintf(f, "n%05" PRIuS " [label=\"%s>%d\"];\n", cur, + PropertyName(tree[cur].property).c_str(), tree[cur].splitval); + fprintf(f, "n%05" PRIuS " -- n%05d;\n", cur, tree[cur].lchild); + fprintf(f, "n%05" PRIuS " -- n%05d;\n", cur, tree[cur].rchild); + } + } + fprintf(f, "}\n"); + fclose(f); +#if JXL_ENABLE_DOT + JXL_ASSERT( + system(("dot " + path + ".dot -T svg -o " + path + ".svg").c_str()) == 0); +#endif +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h new file mode 100644 index 0000000000..78deaab1b8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h @@ -0,0 +1,27 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +const char *PredictorName(Predictor p); +std::string PropertyName(size_t i); + +void PrintTree(const Tree &tree, const std::string &path); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc new file mode 100644 index 0000000000..c8c183335e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc @@ -0,0 +1,562 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <stdint.h> +#include <stdlib.h> + +#include <cinttypes> +#include <limits> +#include <numeric> +#include <queue> +#include <set> +#include <unordered_map> +#include <unordered_set> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_debug_tree.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +// Plot tree (if enabled) and predictor usage map. +constexpr bool kWantDebug = false; +constexpr bool kPrintTree = false; + +inline std::array<uint8_t, 3> PredictorColor(Predictor p) { + switch (p) { + case Predictor::Zero: + return {{0, 0, 0}}; + case Predictor::Left: + return {{255, 0, 0}}; + case Predictor::Top: + return {{0, 255, 0}}; + case Predictor::Average0: + return {{0, 0, 255}}; + case Predictor::Average4: + return {{192, 128, 128}}; + case Predictor::Select: + return {{255, 255, 0}}; + case Predictor::Gradient: + return {{255, 0, 255}}; + case Predictor::Weighted: + return {{0, 255, 255}}; + // TODO + default: + return {{255, 255, 255}}; + }; +} + +} // namespace + +void GatherTreeData(const Image &image, pixel_type chan, size_t group_id, + const weighted::Header &wp_header, + const ModularOptions &options, TreeSamples &tree_samples, + size_t *total_pixels) { + const Channel &channel = image.channel[chan]; + + JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w, + channel.h, chan); + + std::array<pixel_type, kNumStaticProperties> static_props = { + {chan, (int)group_id}}; + Properties properties(kNumNonrefProperties + + kExtraPropsPerChannel * options.max_properties); + double pixel_fraction = std::min(1.0f, options.nb_repeats); + // a fraction of 0 is used to disable learning entirely. + if (pixel_fraction > 0) { + pixel_fraction = std::max(pixel_fraction, + std::min(1.0, 1024.0 / (channel.w * channel.h))); + } + uint64_t threshold = + (std::numeric_limits<uint64_t>::max() >> 32) * pixel_fraction; + uint64_t s[2] = {static_cast<uint64_t>(0x94D049BB133111EBull), + static_cast<uint64_t>(0xBF58476D1CE4E5B9ull)}; + // Xorshift128+ adapted from xorshift128+-inl.h + auto use_sample = [&]() { + auto s1 = s[0]; + const auto s0 = s[1]; + const auto bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return (bits >> 32) <= threshold; + }; + + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + InitPropsRow(&properties, static_props, y); + // TODO(veluca): avoid computing WP if we don't use its property or + // predictions. + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w pred[kNumModularPredictors]; + if (tree_samples.NumPredictors() != 1) { + PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references, + &wp_state, pred); + } else { + pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] = + PredictLearn(&properties, channel.w, p + x, onerow, x, y, + tree_samples.PredictorFromIndex(0), references, + &wp_state) + .guess; + } + (*total_pixels)++; + if (use_sample()) { + tree_samples.AddSample(p[x], properties, pred); + } + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } +} + +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector<ModularMultiplierInfo> &multiplier_info = {}, + StaticPropRange static_prop_range = {}) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + if (static_prop_range[i][1] == 0) { + static_prop_range[i][1] = std::numeric_limits<uint32_t>::max(); + } + } + if (!tree_samples.HasSamples()) { + Tree tree; + tree.emplace_back(); + tree.back().predictor = tree_samples.PredictorFromIndex(0); + tree.back().property = -1; + tree.back().predictor_offset = 0; + tree.back().multiplier = 1; + return tree; + } + float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels; + float required_cost = pixel_fraction * 0.9 + 0.1; + tree_samples.AllSamplesDone(); + Tree tree; + ComputeBestTree(tree_samples, + options.splitting_heuristics_node_threshold * required_cost, + multiplier_info, static_prop_range, + options.fast_decode_multiplier, &tree); + return tree; +} + +Status EncodeModularChannelMAANS(const Image &image, pixel_type chan, + const weighted::Header &wp_header, + const Tree &global_tree, Token **tokenpp, + AuxOut *aux_out, size_t group_id, + bool skip_encoder_fast_path) { + const Channel &channel = image.channel[chan]; + Token *tokenp = *tokenpp; + JXL_ASSERT(channel.w != 0 && channel.h != 0); + + Image3F predictor_img; + if (kWantDebug) predictor_img = Image3F(channel.w, channel.h); + + JXL_DEBUG_V(6, + "Encoding %" PRIuS "x%" PRIuS + " channel %d, " + "(shift=%i,%i)", + channel.w, channel.h, chan, channel.hshift, channel.vshift); + + std::array<pixel_type, kNumStaticProperties> static_props = { + {chan, (int)group_id}}; + bool use_wp, is_wp_only; + bool is_gradient_only; + size_t num_props; + FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp, + &is_wp_only, &is_gradient_only); + Properties properties(num_props); + MATreeLookup tree_lookup(tree); + JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size()); + + // Check if this tree is a WP-only tree with a small enough property value + // range. + // Initialized to avoid clang-tidy complaining. + uint16_t context_lookup[2 * kPropRangeFast] = {}; + int8_t offsets[2 * kPropRangeFast] = {}; + if (is_wp_only) { + is_wp_only = TreeToLookupTable(tree, context_lookup, offsets); + } + if (is_gradient_only) { + is_gradient_only = TreeToLookupTable(tree, context_lookup, offsets); + } + + if (is_wp_only && !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Weighted)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + Properties properties(1); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + int32_t guess = wp_state.Predict</*compute_properties=*/true>( + x, y, channel.w, top, left, topright, topleft, toptop, &properties, + offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + int32_t residual = r[x] - guess - offsets[pos]; + *tokenp++ = Token(ctx_id, PackSigned(residual)); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient && + tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && + !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + int32_t residual = r[x] - guess; + *tokenp++ = Token(tree[0].childID, PackSigned(residual)); + } + } + } else if (is_gradient_only && !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + uint32_t pos = + kPropRangeFast + + std::min<pixel_type_w>( + std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + int32_t residual = r[x] - guess - offsets[pos]; + *tokenp++ = Token(ctx_id, PackSigned(residual)); + } + } + } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero && + tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && + !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Zero)[c]), + &predictor_img.Plane(c)); + } + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + *tokenp++ = Token(tree[0].childID, PackSigned(p[x])); + } + } + } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted && + (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 && + tree[0].predictor_offset == 0 && !skip_encoder_fast_path) { + // multiplier is a power of 2. + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(tree[0].predictor)[c]), + &predictor_img.Plane(c)); + } + uint32_t mul_shift = FloorLog2Nonzero((uint32_t)tree[0].multiplier); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x, + y, tree[0].predictor); + pixel_type_w residual = r[x] - pred.guess; + JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual); + *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift)); + } + } + + } else if (!use_wp && !skip_encoder_fast_path) { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_ASSERT(residual % res.multiplier == 0); + *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); + } + } + } else { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_ASSERT(residual % res.multiplier == 0); + *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + } + if (kWantDebug && WantDebugOutput(aux_out)) { + aux_out->DumpImage( + ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(), + predictor_img); + } + *tokenpp = tokenp; + return true; +} + +Status ModularEncode(const Image &image, const ModularOptions &options, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector<Token> *tokens, + size_t *width) { + if (image.error) return JXL_FAILURE("Invalid image"); + size_t nb_channels = image.channel.size(); + JXL_DEBUG_V( + 2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.", + nb_channels, image.bitdepth, image.w, image.h); + + if (nb_channels < 1) { + return true; // is there any use for a zero-channel image? + } + + // encode transforms + GroupHeader header_storage; + if (header == nullptr) header = &header_storage; + Bundle::Init(header); + if (options.predictor == Predictor::Weighted) { + weighted::PredictorMode(options.wp_mode, &header->wp_header); + } + header->transforms = image.transform; + // This doesn't actually work + if (tree != nullptr) { + header->use_global_tree = true; + } + if (tree_samples == nullptr && tree == nullptr) { + JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out)); + } + + TreeSamples tree_samples_storage; + size_t total_pixels_storage = 0; + if (!total_pixels) total_pixels = &total_pixels_storage; + // If there's no tree, compute one (or gather data to). + if (tree == nullptr) { + bool gather_data = tree_samples != nullptr; + if (tree_samples == nullptr) { + JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor( + options.predictor, options.wp_tree_mode)); + JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties( + options.splitting_heuristics_properties, options.wp_tree_mode)); + std::vector<pixel_type> pixel_samples; + std::vector<pixel_type> diff_samples; + std::vector<uint32_t> group_pixel_count; + std::vector<uint32_t> channel_pixel_count; + CollectPixelSamples(image, options, 0, group_pixel_count, + channel_pixel_count, pixel_samples, diff_samples); + std::vector<ModularMultiplierInfo> dummy_multiplier_info; + StaticPropRange range; + tree_samples_storage.PreQuantizeProperties( + range, dummy_multiplier_info, group_pixel_count, channel_pixel_count, + pixel_samples, diff_samples, options.max_property_values); + } + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + GatherTreeData(image, i, group_id, header->wp_header, options, + gather_data ? *tree_samples : tree_samples_storage, + total_pixels); + } + if (gather_data) return true; + } + + JXL_ASSERT((tree == nullptr) == (tokens == nullptr)); + + Tree tree_storage; + std::vector<std::vector<Token>> tokens_storage(1); + // Compute tree. + if (tree == nullptr) { + EntropyEncodingData code; + std::vector<uint8_t> context_map; + + std::vector<std::vector<Token>> tree_tokens(1); + tree_storage = + LearnTree(std::move(tree_samples_storage), *total_pixels, options); + tree = &tree_storage; + tokens = &tokens_storage[0]; + + Tree decoded_tree; + TokenizeTree(*tree, &tree_tokens[0], &decoded_tree); + JXL_ASSERT(tree->size() == decoded_tree.size()); + tree_storage = std::move(decoded_tree); + + if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) { + PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id)); + } + // Write tree + BuildAndEncodeHistograms(HistogramParams(), kNumTreeContexts, tree_tokens, + &code, &context_map, writer, kLayerModularTree, + aux_out); + WriteTokens(tree_tokens[0], code, context_map, writer, kLayerModularTree, + aux_out); + } + + size_t image_width = 0; + size_t total_tokens = 0; + for (size_t i = 0; i < nb_channels; i++) { + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + if (image.channel[i].w > image_width) image_width = image.channel[i].w; + total_tokens += image.channel[i].w * image.channel[i].h; + } + if (options.zero_tokens) { + tokens->resize(tokens->size() + total_tokens, {0, 0}); + } else { + // Do one big allocation for all the tokens we'll need, + // to avoid reallocs that might require copying. + size_t pos = tokens->size(); + tokens->resize(pos + total_tokens); + Token *tokenp = tokens->data() + pos; + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS( + image, i, header->wp_header, *tree, &tokenp, aux_out, group_id, + options.skip_encoder_fast_path)); + } + // Make sure we actually wrote all tokens + JXL_CHECK(tokenp == tokens->data() + tokens->size()); + } + + // Write data if not using a global tree/ANS stream. + if (!header->use_global_tree) { + EntropyEncodingData code; + std::vector<uint8_t> context_map; + HistogramParams histo_params; + histo_params.image_widths.push_back(image_width); + BuildAndEncodeHistograms(histo_params, (tree->size() + 1) / 2, + tokens_storage, &code, &context_map, writer, layer, + aux_out); + WriteTokens(tokens_storage[0], code, context_map, writer, layer, aux_out); + } else { + *width = image_width; + } + return true; +} + +Status ModularGenericCompress(Image &image, const ModularOptions &opts, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector<Token> *tokens, + size_t *width) { + if (image.w == 0 || image.h == 0) return true; + ModularOptions options = opts; // Make a copy to modify it. + + if (options.predictor == static_cast<Predictor>(-1)) { + options.predictor = Predictor::Gradient; + } + + size_t bits = writer ? writer->BitsWritten() : 0; + JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer, + group_id, tree_samples, total_pixels, tree, + header, tokens, width)); + bits = writer ? writer->BitsWritten() - bits : 0; + if (writer) { + JXL_DEBUG_V(4, + "Modular-encoded a %" PRIuS "x%" PRIuS + " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes", + image.w, image.h, image.bitdepth, image.channel.size(), + bits / 8); + } + (void)bits; + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h new file mode 100644 index 0000000000..04df504750 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h @@ -0,0 +1,47 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector<ModularMultiplierInfo> &multiplier_info = {}, + StaticPropRange static_prop_range = {}); + +// TODO(veluca): make cleaner interfaces. + +Status ModularGenericCompress( + Image &image, const ModularOptions &opts, BitWriter *writer, + AuxOut *aux_out = nullptr, size_t layer = 0, size_t group_id = 0, + // For gathering data for producing a global tree. + TreeSamples *tree_samples = nullptr, size_t *total_pixels = nullptr, + // For encoding with global tree. + const Tree *tree = nullptr, GroupHeader *header = nullptr, + std::vector<Token> *tokens = nullptr, size_t *widths = nullptr); +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc new file mode 100644 index 0000000000..d0f6b47566 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc @@ -0,0 +1,1023 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/enc_ma.h" + +#include <algorithm> +#include <limits> +#include <numeric> +#include <queue> +#include <unordered_map> +#include <unordered_set> + +#include "lib/jxl/modular/encoding/ma_common.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Lt; + +const HWY_FULL(float) df; +const HWY_FULL(int32_t) di; +size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } + +float EstimateBits(const int32_t *counts, int32_t *rounded_counts, + size_t num_symbols) { + // Try to approximate the effect of rounding up nonzero probabilities. + int32_t total = std::accumulate(counts, counts + num_symbols, 0); + const auto min = Set(di, (total + ANS_TAB_SIZE - 1) >> ANS_LOG_TAB_SIZE); + const auto zero_i = Zero(di); + for (size_t i = 0; i < num_symbols; i += Lanes(df)) { + auto counts_v = LoadU(di, &counts[i]); + counts_v = IfThenElse(Eq(counts_v, zero_i), zero_i, + IfThenElse(Lt(counts_v, min), min, counts_v)); + StoreU(counts_v, di, &rounded_counts[i]); + } + // Compute entropy of the "rounded" probabilities. + const auto zero = Zero(df); + const size_t total_scalar = + std::accumulate(rounded_counts, rounded_counts + num_symbols, 0); + const auto inv_total = Set(df, 1.0f / total_scalar); + auto bits_lanes = Zero(df); + auto total_v = Set(di, total_scalar); + for (size_t i = 0; i < num_symbols; i += Lanes(df)) { + const auto counts_v = ConvertTo(df, LoadU(di, &counts[i])); + const auto round_counts_v = LoadU(di, &rounded_counts[i]); + const auto probs = Mul(ConvertTo(df, round_counts_v), inv_total); + const auto nbps = IfThenElse(Eq(round_counts_v, total_v), BitCast(di, zero), + BitCast(di, FastLog2f(df, probs))); + bits_lanes = Sub(bits_lanes, IfThenElse(Eq(counts_v, zero), zero, + Mul(counts_v, BitCast(df, nbps)))); + } + return GetLane(SumOfLanes(df, bits_lanes)); +} + +void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred, + int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { + // Note that the tree splits on *strictly greater*. + (*tree)[pos].lchild = tree->size(); + (*tree)[pos].rchild = tree->size() + 1; + (*tree)[pos].splitval = splitval; + (*tree)[pos].property = property; + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = rpred; + tree->back().predictor_offset = roff; + tree->back().multiplier = 1; + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = lpred; + tree->back().predictor_offset = loff; + tree->back().multiplier = 1; +} + +enum class IntersectionType { kNone, kPartial, kInside }; +IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack, + uint32_t &partial_axis, uint32_t &partial_val) { + bool partial = false; + for (size_t i = 0; i < kNumStaticProperties; i++) { + if (haystack[i][0] >= needle[i][1]) { + return IntersectionType::kNone; + } + if (haystack[i][1] <= needle[i][0]) { + return IntersectionType::kNone; + } + if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { + continue; + } + partial = true; + partial_axis = i; + if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { + partial_val = haystack[i][0] - 1; + } else { + JXL_DASSERT(haystack[i][1] > needle[i][0] && + haystack[i][1] < needle[i][1]); + partial_val = haystack[i][1] - 1; + } + } + return partial ? IntersectionType::kPartial : IntersectionType::kInside; +} + +void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos, + size_t end, size_t prop) { + auto cmp = [&](size_t a, size_t b) { + return int32_t(tree_samples.Property(prop, a)) - + int32_t(tree_samples.Property(prop, b)); + }; + Rng rng(0); + while (end > begin + 1) { + { + size_t pivot = rng.UniformU(begin, end); + tree_samples.Swap(begin, pivot); + } + size_t pivot_begin = begin; + size_t pivot_end = pivot_begin + 1; + for (size_t i = begin + 1; i < end; i++) { + JXL_DASSERT(i >= pivot_end); + JXL_DASSERT(pivot_end > pivot_begin); + int32_t cmp_result = cmp(i, pivot_begin); + if (cmp_result < 0) { // i < pivot, move pivot forward and put i before + // the pivot. + tree_samples.ThreeShuffle(pivot_begin, pivot_end, i); + pivot_begin++; + pivot_end++; + } else if (cmp_result == 0) { + tree_samples.Swap(pivot_end, i); + pivot_end++; + } + } + JXL_DASSERT(pivot_begin >= begin); + JXL_DASSERT(pivot_end > pivot_begin); + JXL_DASSERT(pivot_end <= end); + for (size_t i = begin; i < pivot_begin; i++) { + JXL_DASSERT(cmp(i, pivot_begin) < 0); + } + for (size_t i = pivot_end; i < end; i++) { + JXL_DASSERT(cmp(i, pivot_begin) > 0); + } + for (size_t i = pivot_begin; i < pivot_end; i++) { + JXL_DASSERT(cmp(i, pivot_begin) == 0); + } + // We now have that [begin, pivot_begin) is < pivot, [pivot_begin, + // pivot_end) is = pivot, and [pivot_end, end) is > pivot. + // If pos falls in the first or the last interval, we continue in that + // interval; otherwise, we are done. + if (pivot_begin > pos) { + end = pivot_begin; + } else if (pivot_end < pos) { + begin = pivot_end; + } else { + break; + } + } +} + +void FindBestSplit(TreeSamples &tree_samples, float threshold, + const std::vector<ModularMultiplierInfo> &mul_info, + StaticPropRange initial_static_prop_range, + float fast_decode_multiplier, Tree *tree) { + struct NodeInfo { + size_t pos; + size_t begin; + size_t end; + uint64_t used_properties; + StaticPropRange static_prop_range; + }; + std::vector<NodeInfo> nodes; + nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0, + initial_static_prop_range}); + + size_t num_predictors = tree_samples.NumPredictors(); + size_t num_properties = tree_samples.NumProperties(); + + // TODO(veluca): consider parallelizing the search (processing multiple nodes + // at a time). + while (!nodes.empty()) { + size_t pos = nodes.back().pos; + size_t begin = nodes.back().begin; + size_t end = nodes.back().end; + uint64_t used_properties = nodes.back().used_properties; + StaticPropRange static_prop_range = nodes.back().static_prop_range; + nodes.pop_back(); + if (begin == end) continue; + + struct SplitInfo { + size_t prop = 0; + uint32_t val = 0; + size_t pos = 0; + float lcost = std::numeric_limits<float>::max(); + float rcost = std::numeric_limits<float>::max(); + Predictor lpred = Predictor::Zero; + Predictor rpred = Predictor::Zero; + float Cost() { return lcost + rcost; } + }; + + SplitInfo best_split_static_constant; + SplitInfo best_split_static; + SplitInfo best_split_nonstatic; + SplitInfo best_split_nowp; + + JXL_DASSERT(begin <= end); + JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); + + // Compute the maximum token in the range. + size_t max_symbols = 0; + for (size_t pred = 0; pred < num_predictors; pred++) { + for (size_t i = begin; i < end; i++) { + uint32_t tok = tree_samples.Token(pred, i); + max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1; + } + } + max_symbols = Padded(max_symbols); + std::vector<int32_t> rounded_counts(max_symbols); + std::vector<int32_t> counts(max_symbols * num_predictors); + std::vector<uint32_t> tot_extra_bits(num_predictors); + for (size_t pred = 0; pred < num_predictors; pred++) { + for (size_t i = begin; i < end; i++) { + counts[pred * max_symbols + tree_samples.Token(pred, i)] += + tree_samples.Count(i); + tot_extra_bits[pred] += + tree_samples.NBits(pred, i) * tree_samples.Count(i); + } + } + + float base_bits; + { + size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); + base_bits = EstimateBits(counts.data() + pred * max_symbols, + rounded_counts.data(), max_symbols) + + tot_extra_bits[pred]; + } + + SplitInfo *best = &best_split_nonstatic; + + SplitInfo forced_split; + // The multiplier ranges cut halfway through the current ranges of static + // properties. We do this even if the current node is not a leaf, to + // minimize the number of nodes in the resulting tree. + for (size_t i = 0; i < mul_info.size(); i++) { + uint32_t axis, val; + IntersectionType t = + BoxIntersects(static_prop_range, mul_info[i].range, axis, val); + if (t == IntersectionType::kNone) continue; + if (t == IntersectionType::kInside) { + (*tree)[pos].multiplier = mul_info[i].multiplier; + break; + } + if (t == IntersectionType::kPartial) { + forced_split.val = tree_samples.QuantizeProperty(axis, val); + forced_split.prop = axis; + forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; + forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; + best = &forced_split; + best->pos = begin; + JXL_ASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); + for (size_t x = begin; x < end; x++) { + if (tree_samples.Property(best->prop, x) <= best->val) { + best->pos++; + } + } + break; + } + } + + if (best != &forced_split) { + std::vector<int> prop_value_used_count; + std::vector<int> count_increase; + std::vector<size_t> extra_bits_increase; + // For each property, compute which of its values are used, and what + // tokens correspond to those usages. Then, iterate through the values, + // and compute the entropy of each side of the split (of the form `prop > + // threshold`). Finally, find the split that minimizes the cost. + struct CostInfo { + float cost = std::numeric_limits<float>::max(); + float extra_cost = 0; + float Cost() const { return cost + extra_cost; } + Predictor pred; // will be uninitialized in some cases, but never used. + }; + std::vector<CostInfo> costs_l; + std::vector<CostInfo> costs_r; + + std::vector<int32_t> counts_above(max_symbols); + std::vector<int32_t> counts_below(max_symbols); + + // The lower the threshold, the higher the expected noisiness of the + // estimate. Thus, discourage changing predictors. + float change_pred_penalty = 800.0f / (100.0f + threshold); + for (size_t prop = 0; prop < num_properties && base_bits > threshold; + prop++) { + costs_l.clear(); + costs_r.clear(); + size_t prop_size = tree_samples.NumPropertyValues(prop); + if (extra_bits_increase.size() < prop_size) { + count_increase.resize(prop_size * max_symbols); + extra_bits_increase.resize(prop_size); + } + // Clear prop_value_used_count (which cannot be cleared "on the go") + prop_value_used_count.clear(); + prop_value_used_count.resize(prop_size); + + size_t first_used = prop_size; + size_t last_used = 0; + + // TODO(veluca): consider finding multiple splits along a single + // property at the same time, possibly with a bottom-up approach. + for (size_t i = begin; i < end; i++) { + size_t p = tree_samples.Property(prop, i); + prop_value_used_count[p]++; + last_used = std::max(last_used, p); + first_used = std::min(first_used, p); + } + costs_l.resize(last_used - first_used); + costs_r.resize(last_used - first_used); + // For all predictors, compute the right and left costs of each split. + for (size_t pred = 0; pred < num_predictors; pred++) { + // Compute cost and histogram increments for each property value. + for (size_t i = begin; i < end; i++) { + size_t p = tree_samples.Property(prop, i); + size_t cnt = tree_samples.Count(i); + size_t sym = tree_samples.Token(pred, i); + count_increase[p * max_symbols + sym] += cnt; + extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt; + } + memcpy(counts_above.data(), counts.data() + pred * max_symbols, + max_symbols * sizeof counts_above[0]); + memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); + size_t extra_bits_below = 0; + // Exclude last used: this ensures neither counts_above nor + // counts_below is empty. + for (size_t i = first_used; i < last_used; i++) { + if (!prop_value_used_count[i]) continue; + extra_bits_below += extra_bits_increase[i]; + // The increase for this property value has been used, and will not + // be used again: clear it. Also below. + extra_bits_increase[i] = 0; + for (size_t sym = 0; sym < max_symbols; sym++) { + counts_above[sym] -= count_increase[i * max_symbols + sym]; + counts_below[sym] += count_increase[i * max_symbols + sym]; + count_increase[i * max_symbols + sym] = 0; + } + float rcost = EstimateBits(counts_above.data(), + rounded_counts.data(), max_symbols) + + tot_extra_bits[pred] - extra_bits_below; + float lcost = EstimateBits(counts_below.data(), + rounded_counts.data(), max_symbols) + + extra_bits_below; + JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); + float penalty = 0; + // Never discourage moving away from the Weighted predictor. + if (tree_samples.PredictorFromIndex(pred) != + (*tree)[pos].predictor && + (*tree)[pos].predictor != Predictor::Weighted) { + penalty = change_pred_penalty; + } + // If everything else is equal, disfavour Weighted (slower) and + // favour Zero (faster if it's the only predictor used in a + // group+channel combination) + if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { + penalty += 1e-8; + } + if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { + penalty -= 1e-8; + } + if (rcost + penalty < costs_r[i - first_used].Cost()) { + costs_r[i - first_used].cost = rcost; + costs_r[i - first_used].extra_cost = penalty; + costs_r[i - first_used].pred = + tree_samples.PredictorFromIndex(pred); + } + if (lcost + penalty < costs_l[i - first_used].Cost()) { + costs_l[i - first_used].cost = lcost; + costs_l[i - first_used].extra_cost = penalty; + costs_l[i - first_used].pred = + tree_samples.PredictorFromIndex(pred); + } + } + } + // Iterate through the possible splits and find the one with minimum sum + // of costs of the two sides. + size_t split = begin; + for (size_t i = first_used; i < last_used; i++) { + if (!prop_value_used_count[i]) continue; + split += prop_value_used_count[i]; + float rcost = costs_r[i - first_used].cost; + float lcost = costs_l[i - first_used].cost; + // WP was not used + we would use the WP property or predictor + bool adds_wp = + (tree_samples.PropertyFromIndex(prop) == kWPProp && + (used_properties & (1LU << prop)) == 0) || + ((costs_l[i - first_used].pred == Predictor::Weighted || + costs_r[i - first_used].pred == Predictor::Weighted) && + (*tree)[pos].predictor != Predictor::Weighted); + bool zero_entropy_side = rcost == 0 || lcost == 0; + + SplitInfo &best = + prop < kNumStaticProperties + ? (zero_entropy_side ? best_split_static_constant + : best_split_static) + : (adds_wp ? best_split_nonstatic : best_split_nowp); + if (lcost + rcost < best.Cost()) { + best.prop = prop; + best.val = i; + best.pos = split; + best.lcost = lcost; + best.lpred = costs_l[i - first_used].pred; + best.rcost = rcost; + best.rpred = costs_r[i - first_used].pred; + } + } + // Clear extra_bits_increase and cost_increase for last_used. + extra_bits_increase[last_used] = 0; + for (size_t sym = 0; sym < max_symbols; sym++) { + count_increase[last_used * max_symbols + sym] = 0; + } + } + + // Try to avoid introducing WP. + if (best_split_nowp.Cost() + threshold < base_bits && + best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { + best = &best_split_nowp; + } + // Split along static props if possible and not significantly more + // expensive. + if (best_split_static.Cost() + threshold < base_bits && + best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { + best = &best_split_static; + } + // Split along static props to create constant nodes if possible. + if (best_split_static_constant.Cost() + threshold < base_bits) { + best = &best_split_static_constant; + } + } + + if (best->Cost() + threshold < base_bits) { + uint32_t p = tree_samples.PropertyFromIndex(best->prop); + pixel_type dequant = + tree_samples.UnquantizeProperty(best->prop, best->val); + // Split node and try to split children. + MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); + // "Sort" according to winning property + SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop); + if (p >= kNumStaticProperties) { + used_properties |= 1 << best->prop; + } + auto new_sp_range = static_prop_range; + if (p < kNumStaticProperties) { + JXL_ASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); + new_sp_range[p][1] = dequant + 1; + JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]); + } + nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos, + used_properties, new_sp_range}); + new_sp_range = static_prop_range; + if (p < kNumStaticProperties) { + JXL_ASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); + new_sp_range[p][0] = dequant + 1; + JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]); + } + nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end, + used_properties, new_sp_range}); + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(FindBestSplit); // Local function. + +void ComputeBestTree(TreeSamples &tree_samples, float threshold, + const std::vector<ModularMultiplierInfo> &mul_info, + StaticPropRange static_prop_range, + float fast_decode_multiplier, Tree *tree) { + // TODO(veluca): take into account that different contexts can have different + // uint configs. + // + // Initialize tree. + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = tree_samples.PredictorFromIndex(0); + tree->back().predictor_offset = 0; + tree->back().multiplier = 1; + JXL_ASSERT(tree_samples.NumProperties() < 64); + + JXL_ASSERT(tree_samples.NumDistinctSamples() <= + std::numeric_limits<uint32_t>::max()); + HWY_DYNAMIC_DISPATCH(FindBestSplit) + (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier, + tree); +} + +constexpr int32_t TreeSamples::kPropertyRange; +constexpr uint32_t TreeSamples::kDedupEntryUnused; + +Status TreeSamples::SetPredictor(Predictor predictor, + ModularOptions::TreeMode wp_tree_mode) { + if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { + predictors = {Predictor::Weighted}; + residuals.resize(1); + return true; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP && + predictor == Predictor::Weighted) { + return JXL_FAILURE("Invalid predictor settings"); + } + if (predictor == Predictor::Variable) { + for (size_t i = 0; i < kNumModularPredictors; i++) { + predictors.push_back(static_cast<Predictor>(i)); + } + std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]); + std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]); + } else if (predictor == Predictor::Best) { + predictors = {Predictor::Weighted, Predictor::Gradient}; + } else { + predictors = {predictor}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { + auto wp_it = + std::find(predictors.begin(), predictors.end(), Predictor::Weighted); + if (wp_it != predictors.end()) { + predictors.erase(wp_it); + } + } + residuals.resize(predictors.size()); + return true; +} + +Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties, + ModularOptions::TreeMode wp_tree_mode) { + props_to_use = properties; + if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { + props_to_use = {static_cast<uint32_t>(kWPProp)}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) { + props_to_use = {static_cast<uint32_t>(kGradientProp)}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { + auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp); + if (it != props_to_use.end()) { + props_to_use.erase(it); + } + } + if (props_to_use.empty()) { + return JXL_FAILURE("Invalid property set configuration"); + } + props.resize(props_to_use.size()); + return true; +} + +void TreeSamples::InitTable(size_t size) { + JXL_DASSERT((size & (size - 1)) == 0); + if (dedup_table_.size() == size) return; + dedup_table_.resize(size, kDedupEntryUnused); + for (size_t i = 0; i < NumDistinctSamples(); i++) { + if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) { + AddToTable(i); + } + } +} + +bool TreeSamples::AddToTableAndMerge(size_t a) { + size_t pos1 = Hash1(a); + size_t pos2 = Hash2(a); + if (dedup_table_[pos1] != kDedupEntryUnused && + IsSameSample(a, dedup_table_[pos1])) { + JXL_DASSERT(sample_counts[a] == 1); + sample_counts[dedup_table_[pos1]]++; + // Remove from hash table samples that are saturated. + if (sample_counts[dedup_table_[pos1]] == + std::numeric_limits<uint16_t>::max()) { + dedup_table_[pos1] = kDedupEntryUnused; + } + return true; + } + if (dedup_table_[pos2] != kDedupEntryUnused && + IsSameSample(a, dedup_table_[pos2])) { + JXL_DASSERT(sample_counts[a] == 1); + sample_counts[dedup_table_[pos2]]++; + // Remove from hash table samples that are saturated. + if (sample_counts[dedup_table_[pos2]] == + std::numeric_limits<uint16_t>::max()) { + dedup_table_[pos2] = kDedupEntryUnused; + } + return true; + } + AddToTable(a); + return false; +} + +void TreeSamples::AddToTable(size_t a) { + size_t pos1 = Hash1(a); + size_t pos2 = Hash2(a); + if (dedup_table_[pos1] == kDedupEntryUnused) { + dedup_table_[pos1] = a; + } else if (dedup_table_[pos2] == kDedupEntryUnused) { + dedup_table_[pos2] = a; + } +} + +void TreeSamples::PrepareForSamples(size_t num_samples) { + for (auto &res : residuals) { + res.reserve(res.size() + num_samples); + } + for (auto &p : props) { + p.reserve(p.size() + num_samples); + } + size_t total_num_samples = num_samples + sample_counts.size(); + size_t next_pow2 = 1LLU << CeilLog2Nonzero(total_num_samples * 3 / 2); + InitTable(next_pow2); +} + +size_t TreeSamples::Hash1(size_t a) const { + constexpr uint64_t constant = 0x1e35a7bd; + uint64_t h = constant; + for (const auto &r : residuals) { + h = h * constant + r[a].tok; + h = h * constant + r[a].nbits; + } + for (const auto &p : props) { + h = h * constant + p[a]; + } + return (h >> 16) & (dedup_table_.size() - 1); +} +size_t TreeSamples::Hash2(size_t a) const { + constexpr uint64_t constant = 0x1e35a7bd1e35a7bd; + uint64_t h = constant; + for (const auto &p : props) { + h = h * constant ^ p[a]; + } + for (const auto &r : residuals) { + h = h * constant ^ r[a].tok; + h = h * constant ^ r[a].nbits; + } + return (h >> 16) & (dedup_table_.size() - 1); +} + +bool TreeSamples::IsSameSample(size_t a, size_t b) const { + bool ret = true; + for (const auto &r : residuals) { + if (r[a].tok != r[b].tok) { + ret = false; + } + if (r[a].nbits != r[b].nbits) { + ret = false; + } + } + for (const auto &p : props) { + if (p[a] != p[b]) { + ret = false; + } + } + return ret; +} + +void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties, + const pixel_type_w *predictions) { + for (size_t i = 0; i < predictors.size(); i++) { + pixel_type v = pixel - predictions[static_cast<int>(predictors[i])]; + uint32_t tok, nbits, bits; + HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits); + JXL_DASSERT(tok < 256); + JXL_DASSERT(nbits < 256); + residuals[i].emplace_back( + ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)}); + } + for (size_t i = 0; i < props_to_use.size(); i++) { + props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]])); + } + sample_counts.push_back(1); + num_samples++; + if (AddToTableAndMerge(sample_counts.size() - 1)) { + for (auto &r : residuals) r.pop_back(); + for (auto &p : props) p.pop_back(); + sample_counts.pop_back(); + } +} + +void TreeSamples::Swap(size_t a, size_t b) { + if (a == b) return; + for (auto &r : residuals) { + std::swap(r[a], r[b]); + } + for (auto &p : props) { + std::swap(p[a], p[b]); + } + std::swap(sample_counts[a], sample_counts[b]); +} + +void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) { + if (b == c) return Swap(a, b); + for (auto &r : residuals) { + auto tmp = r[a]; + r[a] = r[c]; + r[c] = r[b]; + r[b] = tmp; + } + for (auto &p : props) { + auto tmp = p[a]; + p[a] = p[c]; + p[c] = p[b]; + p[b] = tmp; + } + auto tmp = sample_counts[a]; + sample_counts[a] = sample_counts[c]; + sample_counts[c] = sample_counts[b]; + sample_counts[b] = tmp; +} + +namespace { +std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram, + size_t num_chunks) { + if (histogram.empty()) return {}; + // TODO(veluca): selecting distinct quantiles is likely not the best + // way to go about this. + std::vector<int32_t> thresholds; + size_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU); + size_t cumsum = 0; + size_t threshold = 0; + for (size_t i = 0; i + 1 < histogram.size(); i++) { + cumsum += histogram[i]; + if (cumsum > (threshold + 1) * sum / num_chunks) { + thresholds.push_back(i); + while (cumsum >= (threshold + 1) * sum / num_chunks) threshold++; + } + } + return thresholds; +} + +std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples, + size_t num_chunks) { + if (samples.empty()) return {}; + int min = *std::min_element(samples.begin(), samples.end()); + constexpr int kRange = 512; + min = std::min(std::max(min, -kRange), kRange); + std::vector<uint32_t> counts(2 * kRange + 1); + for (int s : samples) { + uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min; + counts[sample_offset]++; + } + std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks); + for (auto &v : thresholds) v += min; + return thresholds; +} +} // namespace + +void TreeSamples::PreQuantizeProperties( + const StaticPropRange &range, + const std::vector<ModularMultiplierInfo> &multiplier_info, + const std::vector<uint32_t> &group_pixel_count, + const std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples, size_t max_property_values) { + // If we have forced splits because of multipliers, choose channel and group + // thresholds accordingly. + std::vector<int32_t> group_multiplier_thresholds; + std::vector<int32_t> channel_multiplier_thresholds; + for (const auto &v : multiplier_info) { + if (v.range[0][0] != range[0][0]) { + channel_multiplier_thresholds.push_back(v.range[0][0] - 1); + } + if (v.range[0][1] != range[0][1]) { + channel_multiplier_thresholds.push_back(v.range[0][1] - 1); + } + if (v.range[1][0] != range[1][0]) { + group_multiplier_thresholds.push_back(v.range[1][0] - 1); + } + if (v.range[1][1] != range[1][1]) { + group_multiplier_thresholds.push_back(v.range[1][1] - 1); + } + } + std::sort(channel_multiplier_thresholds.begin(), + channel_multiplier_thresholds.end()); + channel_multiplier_thresholds.resize( + std::unique(channel_multiplier_thresholds.begin(), + channel_multiplier_thresholds.end()) - + channel_multiplier_thresholds.begin()); + std::sort(group_multiplier_thresholds.begin(), + group_multiplier_thresholds.end()); + group_multiplier_thresholds.resize( + std::unique(group_multiplier_thresholds.begin(), + group_multiplier_thresholds.end()) - + group_multiplier_thresholds.begin()); + + compact_properties.resize(props_to_use.size()); + auto quantize_channel = [&]() { + if (!channel_multiplier_thresholds.empty()) { + return channel_multiplier_thresholds; + } + return QuantizeHistogram(channel_pixel_count, max_property_values); + }; + auto quantize_group_id = [&]() { + if (!group_multiplier_thresholds.empty()) { + return group_multiplier_thresholds; + } + return QuantizeHistogram(group_pixel_count, max_property_values); + }; + auto quantize_coordinate = [&]() { + std::vector<int32_t> quantized; + quantized.reserve(max_property_values - 1); + for (size_t i = 0; i + 1 < max_property_values; i++) { + quantized.push_back((i + 1) * 256 / max_property_values - 1); + } + return quantized; + }; + std::vector<int32_t> abs_pixel_thr; + std::vector<int32_t> pixel_thr; + auto quantize_pixel_property = [&]() { + if (pixel_thr.empty()) { + pixel_thr = QuantizeSamples(pixel_samples, max_property_values); + } + return pixel_thr; + }; + auto quantize_abs_pixel_property = [&]() { + if (abs_pixel_thr.empty()) { + quantize_pixel_property(); // Compute the non-abs thresholds. + for (auto &v : pixel_samples) v = std::abs(v); + abs_pixel_thr = QuantizeSamples(pixel_samples, max_property_values); + } + return abs_pixel_thr; + }; + std::vector<int32_t> abs_diff_thr; + std::vector<int32_t> diff_thr; + auto quantize_diff_property = [&]() { + if (diff_thr.empty()) { + diff_thr = QuantizeSamples(diff_samples, max_property_values); + } + return diff_thr; + }; + auto quantize_abs_diff_property = [&]() { + if (abs_diff_thr.empty()) { + quantize_diff_property(); // Compute the non-abs thresholds. + for (auto &v : diff_samples) v = std::abs(v); + abs_diff_thr = QuantizeSamples(diff_samples, max_property_values); + } + return abs_diff_thr; + }; + auto quantize_wp = [&]() { + if (max_property_values < 32) { + return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0, + 1, 3, 7, 15, 31, 63, 127}; + } + if (max_property_values < 64) { + return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23, + -15, -11, -7, -5, -3, -1, 0, 1, + 3, 5, 7, 11, 15, 23, 31, 47, + 63, 95, 127, 191, 255}; + } + return std::vector<int32_t>{ + -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47, + -39, -31, -27, -23, -19, -15, -13, -11, -9, -7, -6, + -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, + 6, 7, 9, 11, 13, 15, 19, 23, 27, 31, 39, + 47, 55, 63, 79, 95, 111, 127, 159, 191, 223, 255}; + }; + + property_mapping.resize(props_to_use.size()); + for (size_t i = 0; i < props_to_use.size(); i++) { + if (props_to_use[i] == 0) { + compact_properties[i] = quantize_channel(); + } else if (props_to_use[i] == 1) { + compact_properties[i] = quantize_group_id(); + } else if (props_to_use[i] == 2 || props_to_use[i] == 3) { + compact_properties[i] = quantize_coordinate(); + } else if (props_to_use[i] == 6 || props_to_use[i] == 7 || + props_to_use[i] == 8 || + (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) { + compact_properties[i] = quantize_pixel_property(); + } else if (props_to_use[i] == 4 || props_to_use[i] == 5 || + (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) { + compact_properties[i] = quantize_abs_pixel_property(); + } else if (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 2) { + compact_properties[i] = quantize_abs_diff_property(); + } else if (props_to_use[i] == kWPProp) { + compact_properties[i] = quantize_wp(); + } else { + compact_properties[i] = quantize_diff_property(); + } + property_mapping[i].resize(kPropertyRange * 2 + 1); + size_t mapped = 0; + for (size_t j = 0; j < property_mapping[i].size(); j++) { + while (mapped < compact_properties[i].size() && + static_cast<int>(j) - kPropertyRange > + compact_properties[i][mapped]) { + mapped++; + } + // property_mapping[i] of a value V is `mapped` if + // compact_properties[i][mapped] <= j and + // compact_properties[i][mapped-1] > j + // This is because the decision node in the tree splits on (property) > j, + // hence everything that is not > of a threshold should be clustered + // together. + property_mapping[i][j] = mapped; + } + } +} + +void CollectPixelSamples(const Image &image, const ModularOptions &options, + size_t group_id, + std::vector<uint32_t> &group_pixel_count, + std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples) { + if (options.nb_repeats == 0) return; + if (group_pixel_count.size() <= group_id) { + group_pixel_count.resize(group_id + 1); + } + if (channel_pixel_count.size() < image.channel.size()) { + channel_pixel_count.resize(image.channel.size()); + } + Rng rng(group_id); + // Sample 10% of the final number of samples for property quantization. + float fraction = std::min(options.nb_repeats * 0.1, 0.99); + Rng::GeometricDistribution dist(fraction); + size_t total_pixels = 0; + std::vector<size_t> channel_ids; + for (size_t i = 0; i < image.channel.size(); i++) { + if (image.channel[i].w <= 1 || image.channel[i].h == 0) { + continue; // skip empty or width-1 channels. + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + channel_ids.push_back(i); + group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h; + channel_pixel_count[i] += image.channel[i].w * image.channel[i].h; + total_pixels += image.channel[i].w * image.channel[i].h; + } + if (channel_ids.empty()) return; + pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels); + diff_samples.reserve(diff_samples.size() + fraction * total_pixels); + size_t i = 0; + size_t y = 0; + size_t x = 0; + auto advance = [&](size_t amount) { + x += amount; + // Detect row overflow (rare). + while (x >= image.channel[channel_ids[i]].w) { + x -= image.channel[channel_ids[i]].w; + y++; + // Detect end-of-channel (even rarer). + if (y == image.channel[channel_ids[i]].h) { + i++; + y = 0; + if (i >= channel_ids.size()) { + return; + } + } + } + }; + advance(rng.Geometric(dist)); + for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) { + const pixel_type *row = image.channel[channel_ids[i]].Row(y); + pixel_samples.push_back(row[x]); + size_t xp = x == 0 ? 1 : x - 1; + diff_samples.push_back((int64_t)row[x] - row[xp]); + } +} + +// TODO(veluca): very simple encoding scheme. This should be improved. +void TokenizeTree(const Tree &tree, std::vector<Token> *tokens, + Tree *decoder_tree) { + JXL_ASSERT(tree.size() <= kMaxTreeSize); + std::queue<int> q; + q.push(0); + size_t leaf_id = 0; + decoder_tree->clear(); + while (!q.empty()) { + int cur = q.front(); + q.pop(); + JXL_ASSERT(tree[cur].property >= -1); + tokens->emplace_back(kPropertyContext, tree[cur].property + 1); + if (tree[cur].property == -1) { + tokens->emplace_back(kPredictorContext, + static_cast<int>(tree[cur].predictor)); + tokens->emplace_back(kOffsetContext, + PackSigned(tree[cur].predictor_offset)); + uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier); + uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1; + tokens->emplace_back(kMultiplierLogContext, mul_log); + tokens->emplace_back(kMultiplierBitsContext, mul_bits); + JXL_ASSERT(tree[cur].predictor < Predictor::Best); + decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor, + tree[cur].predictor_offset, + tree[cur].multiplier); + continue; + } + decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval, + decoder_tree->size() + q.size() + 1, + decoder_tree->size() + q.size() + 2, + Predictor::Zero, 0, 1); + q.push(tree[cur].lchild); + q.push(tree[cur].rchild); + tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval)); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h new file mode 100644 index 0000000000..ede37c8023 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ + +#include <numeric> + +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// Struct to collect all the data needed to build a tree. +struct TreeSamples { + bool HasSamples() const { + return !residuals.empty() && !residuals[0].empty(); + } + size_t NumDistinctSamples() const { return sample_counts.size(); } + size_t NumSamples() const { return num_samples; } + // Set the predictor to use. Must be called before adding any samples. + Status SetPredictor(Predictor predictor, + ModularOptions::TreeMode wp_tree_mode); + // Set the properties to use. Must be called before adding any samples. + Status SetProperties(const std::vector<uint32_t> &properties, + ModularOptions::TreeMode wp_tree_mode); + + size_t Token(size_t pred, size_t i) const { return residuals[pred][i].tok; } + size_t NBits(size_t pred, size_t i) const { return residuals[pred][i].nbits; } + size_t Count(size_t i) const { return sample_counts[i]; } + size_t PredictorIndex(Predictor predictor) const { + const auto predictor_elem = + std::find(predictors.begin(), predictors.end(), predictor); + JXL_DASSERT(predictor_elem != predictors.end()); + return predictor_elem - predictors.begin(); + } + size_t PropertyIndex(size_t property) const { + const auto property_elem = + std::find(props_to_use.begin(), props_to_use.end(), property); + JXL_DASSERT(property_elem != props_to_use.end()); + return property_elem - props_to_use.begin(); + } + size_t NumPropertyValues(size_t property_index) const { + return compact_properties[property_index].size() + 1; + } + // Returns the *quantized* property value. + size_t Property(size_t property_index, size_t i) const { + return props[property_index][i]; + } + int UnquantizeProperty(size_t property_index, uint32_t quant) const { + JXL_ASSERT(quant < compact_properties[property_index].size()); + return compact_properties[property_index][quant]; + } + + Predictor PredictorFromIndex(size_t index) const { + JXL_DASSERT(index < predictors.size()); + return predictors[index]; + } + size_t PropertyFromIndex(size_t index) const { + JXL_DASSERT(index < props_to_use.size()); + return props_to_use[index]; + } + size_t NumPredictors() const { return predictors.size(); } + size_t NumProperties() const { return props_to_use.size(); } + + // Preallocate data for a given number of samples. MUST be called before + // adding any sample. + void PrepareForSamples(size_t num_samples); + // Add a sample. + void AddSample(pixel_type_w pixel, const Properties &properties, + const pixel_type_w *predictions); + // Pre-cluster property values. + void PreQuantizeProperties( + const StaticPropRange &range, + const std::vector<ModularMultiplierInfo> &multiplier_info, + const std::vector<uint32_t> &group_pixel_count, + const std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples, size_t max_property_values); + + void AllSamplesDone() { dedup_table_ = std::vector<uint32_t>(); } + + uint32_t QuantizeProperty(uint32_t prop, pixel_type v) const { + v = std::min(std::max(v, -kPropertyRange), kPropertyRange) + kPropertyRange; + return property_mapping[prop][v]; + } + + // Swaps samples in position a and b. Does nothing if a == b. + void Swap(size_t a, size_t b); + + // Cycles samples: a -> b -> c -> a. We assume a <= b <= c, so that we can + // just call Swap(a, b) if b==c. + void ThreeShuffle(size_t a, size_t b, size_t c); + + private: + // TODO(veluca): as the total number of properties and predictors are known + // before adding any samples, it might be better to interleave predictors, + // properties and counts in a single vector to improve locality. + // A first attempt at doing this actually results in much slower encoding, + // possibly because of the more complex addressing. + struct ResidualToken { + uint8_t tok; + uint8_t nbits; + }; + // Residual information: token and number of extra bits, per predictor. + std::vector<std::vector<ResidualToken>> residuals; + // Number of occurrences of each sample. + std::vector<uint16_t> sample_counts; + // Property values, quantized to at most 256 distinct values. + std::vector<std::vector<uint8_t>> props; + // Decompactification info for `props`. + std::vector<std::vector<int32_t>> compact_properties; + // List of properties to use. + std::vector<uint32_t> props_to_use; + // List of predictors to use. + std::vector<Predictor> predictors; + // Mapping property value -> quantized property value. + static constexpr int32_t kPropertyRange = 511; + std::vector<std::vector<uint8_t>> property_mapping; + // Number of samples seen. + size_t num_samples = 0; + // Table for deduplication. + static constexpr uint32_t kDedupEntryUnused{static_cast<uint32_t>(-1)}; + std::vector<uint32_t> dedup_table_; + + // Functions for sample deduplication. + bool IsSameSample(size_t a, size_t b) const; + size_t Hash1(size_t a) const; + size_t Hash2(size_t a) const; + void InitTable(size_t size); + // Returns true if `a` was already present in the table. + bool AddToTableAndMerge(size_t a); + void AddToTable(size_t a); +}; + +void TokenizeTree(const Tree &tree, std::vector<Token> *tokens, + Tree *decoder_tree); + +void CollectPixelSamples(const Image &image, const ModularOptions &options, + size_t group_id, + std::vector<uint32_t> &group_pixel_count, + std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples); + +void ComputeBestTree(TreeSamples &tree_samples, float threshold, + const std::vector<ModularMultiplierInfo> &mul_info, + StaticPropRange static_prop_range, + float fast_decode_multiplier, Tree *tree); + +} // namespace jxl +#endif // LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc new file mode 100644 index 0000000000..9d2c3e5cf9 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc @@ -0,0 +1,622 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/encoding.h" + +#include <stdint.h> +#include <stdlib.h> + +#include <queue> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/scope_guard.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// Removes all nodes that use a static property (i.e. channel or group ID) from +// the tree and collapses each node on even levels with its two children to +// produce a flatter tree. Also computes whether the resulting tree requires +// using the weighted predictor. +FlatTree FilterTree(const Tree &global_tree, + std::array<pixel_type, kNumStaticProperties> &static_props, + size_t *num_props, bool *use_wp, bool *wp_only, + bool *gradient_only) { + *num_props = 0; + bool has_wp = false; + bool has_non_wp = false; + *gradient_only = true; + const auto mark_property = [&](int32_t p) { + if (p == kWPProp) { + has_wp = true; + } else if (p >= kNumStaticProperties) { + has_non_wp = true; + } + if (p >= kNumStaticProperties && p != kGradientProp) { + *gradient_only = false; + } + }; + FlatTree output; + std::queue<size_t> nodes; + nodes.push(0); + // Produces a trimmed and flattened tree by doing a BFS visit of the original + // tree, ignoring branches that are known to be false and proceeding two + // levels at a time to collapse nodes in a flatter tree; if an inner parent + // node has a leaf as a child, the leaf is duplicated and an implicit fake + // node is added. This allows to reduce the number of branches when traversing + // the resulting flat tree. + while (!nodes.empty()) { + size_t cur = nodes.front(); + nodes.pop(); + // Skip nodes that we can decide now, by jumping directly to their children. + while (global_tree[cur].property < kNumStaticProperties && + global_tree[cur].property != -1) { + if (static_props[global_tree[cur].property] > global_tree[cur].splitval) { + cur = global_tree[cur].lchild; + } else { + cur = global_tree[cur].rchild; + } + } + FlatDecisionNode flat; + if (global_tree[cur].property == -1) { + flat.property0 = -1; + flat.childID = global_tree[cur].lchild; + flat.predictor = global_tree[cur].predictor; + flat.predictor_offset = global_tree[cur].predictor_offset; + flat.multiplier = global_tree[cur].multiplier; + *gradient_only &= flat.predictor == Predictor::Gradient; + has_wp |= flat.predictor == Predictor::Weighted; + has_non_wp |= flat.predictor != Predictor::Weighted; + output.push_back(flat); + continue; + } + flat.childID = output.size() + nodes.size() + 1; + + flat.property0 = global_tree[cur].property; + *num_props = std::max<size_t>(flat.property0 + 1, *num_props); + flat.splitval0 = global_tree[cur].splitval; + + for (size_t i = 0; i < 2; i++) { + size_t cur_child = + i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild; + // Skip nodes that we can decide now. + while (global_tree[cur_child].property < kNumStaticProperties && + global_tree[cur_child].property != -1) { + if (static_props[global_tree[cur_child].property] > + global_tree[cur_child].splitval) { + cur_child = global_tree[cur_child].lchild; + } else { + cur_child = global_tree[cur_child].rchild; + } + } + // We ended up in a leaf, add a dummy decision and two copies of the leaf. + if (global_tree[cur_child].property == -1) { + flat.properties[i] = 0; + flat.splitvals[i] = 0; + nodes.push(cur_child); + nodes.push(cur_child); + } else { + flat.properties[i] = global_tree[cur_child].property; + flat.splitvals[i] = global_tree[cur_child].splitval; + nodes.push(global_tree[cur_child].lchild); + nodes.push(global_tree[cur_child].rchild); + *num_props = std::max<size_t>(flat.properties[i] + 1, *num_props); + } + } + + for (size_t j = 0; j < 2; j++) mark_property(flat.properties[j]); + mark_property(flat.property0); + output.push_back(flat); + } + if (*num_props > kNumNonrefProperties) { + *num_props = + DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) * + kExtraPropsPerChannel + + kNumNonrefProperties; + } else { + *num_props = kNumNonrefProperties; + } + *use_wp = has_wp; + *wp_only = has_wp && !has_non_wp; + + return output; +} + +Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader, + const std::vector<uint8_t> &context_map, + const Tree &global_tree, + const weighted::Header &wp_header, + pixel_type chan, size_t group_id, + Image *image) { + Channel &channel = image->channel[chan]; + + std::array<pixel_type, kNumStaticProperties> static_props = { + {chan, (int)group_id}}; + // TODO(veluca): filter the tree according to static_props. + + // zero pixel channel? could happen + if (channel.w == 0 || channel.h == 0) return true; + + bool tree_has_wp_prop_or_pred = false; + bool is_wp_only = false; + bool is_gradient_only = false; + size_t num_props; + FlatTree tree = + FilterTree(global_tree, static_props, &num_props, + &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only); + + // From here on, tree lookup returns a *clustered* context ID. + // This avoids an extra memory lookup after tree traversal. + for (size_t i = 0; i < tree.size(); i++) { + if (tree[i].property0 == -1) { + tree[i].childID = context_map[tree[i].childID]; + } + } + + JXL_DEBUG_V(3, "Decoded MA tree with %" PRIuS " nodes", tree.size()); + + // MAANS decode + const auto make_pixel = [](uint64_t v, pixel_type multiplier, + pixel_type_w offset) -> pixel_type { + JXL_DASSERT((v & 0xFFFFFFFF) == v); + pixel_type_w val = UnpackSigned(v); + // if it overflows, it overflows, and we have a problem anyway + return val * multiplier + offset; + }; + + if (tree.size() == 1) { + // special optimized case: no meta-adaptation, so no need + // to compute properties. + Predictor predictor = tree[0].predictor; + int64_t offset = tree[0].predictor_offset; + int32_t multiplier = tree[0].multiplier; + size_t ctx_id = tree[0].childID; + if (predictor == Predictor::Zero) { + uint32_t value; + if (reader->IsSingleValueAndAdvance(ctx_id, &value, + channel.w * channel.h)) { + // Special-case: histogram has a single symbol, with no extra bits, and + // we use ANS mode. + JXL_DEBUG_V(8, "Fastest track."); + pixel_type v = make_pixel(value, multiplier, offset); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + std::fill(r, r + channel.w, v); + } + } else { + JXL_DEBUG_V(8, "Fast track."); + if (multiplier == 1 && offset == 0) { + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + uint32_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = UnpackSigned(v); + } + } + } else { + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + uint32_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multiplier, offset); + } + } + } + } + } else if (predictor == Predictor::Gradient && offset == 0 && + multiplier == 1 && reader->HuffRleOnly()) { + JXL_DEBUG_V(8, "Gradient RLE (fjxl) very fast track."); + uint32_t run = 0; + uint32_t v = 0; + pixel_type_w sv = 0; + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1); + const pixel_type *JXL_RESTRICT rtopleft = + (y ? channel.Row(y - 1) - 1 : r - 1); + pixel_type_w guess = (y ? rtop[0] : 0); + if (run == 0) { + reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run); + sv = UnpackSigned(v); + } else { + run--; + } + r[0] = sv + guess; + for (size_t x = 1; x < channel.w; x++) { + pixel_type left = r[x - 1]; + pixel_type top = rtop[x]; + pixel_type topleft = rtopleft[x]; + pixel_type_w guess = ClampedGradient(top, left, topleft); + if (!run) { + reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run); + sv = UnpackSigned(v); + } else { + run--; + } + r[x] = sv + guess; + } + } + } else if (predictor == Predictor::Gradient && offset == 0 && + multiplier == 1) { + JXL_DEBUG_V(8, "Gradient very fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type top = (y ? *(r + x - onerow) : left); + pixel_type topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type guess = ClampedGradient(top, left, topleft); + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, 1, guess); + } + } + } else if (predictor != Predictor::Weighted) { + // special optimized case: no wp + JXL_DEBUG_V(8, "Quite fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult pred = + PredictNoTreeNoWP(channel.w, r + x, onerow, x, y, predictor); + pixel_type_w g = pred.guess + offset; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + // NOTE: pred.multiplier is unset. + r[x] = make_pixel(v, multiplier, g); + } + } + } else { + JXL_DEBUG_V(8, "Somewhat fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w g = PredictNoTreeWP(channel.w, r + x, onerow, x, y, + predictor, &wp_state) + .guess + + offset; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multiplier, g); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } + return true; + } + + // Check if this tree is a WP-only tree with a small enough property value + // range. + // Initialized to avoid clang-tidy complaining. + uint8_t context_lookup[2 * kPropRangeFast] = {}; + int8_t multipliers[2 * kPropRangeFast] = {}; + int8_t offsets[2 * kPropRangeFast] = {}; + if (is_wp_only) { + is_wp_only = TreeToLookupTable(tree, context_lookup, offsets, multipliers); + } + if (is_gradient_only) { + is_gradient_only = + TreeToLookupTable(tree, context_lookup, offsets, multipliers); + } + + if (is_gradient_only) { + JXL_DEBUG_V(8, "Gradient fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + uint32_t pos = + kPropRangeFast + + std::min<pixel_type_w>( + std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multipliers[pos], + static_cast<pixel_type_w>(offsets[pos]) + guess); + } + } + } else if (is_wp_only) { + JXL_DEBUG_V(8, "WP fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + Properties properties(1); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + int32_t guess = wp_state.Predict</*compute_properties=*/true>( + x, y, channel.w, top, left, topright, topleft, toptop, &properties, + offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multipliers[pos], + static_cast<pixel_type_w>(offsets[pos]) + guess); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } else if (!tree_has_wp_prop_or_pred) { + // special optimized case: the weighted predictor and its properties are not + // used, so no need to compute weights and properties. + JXL_DEBUG_V(8, "Slow track."); + MATreeLookup tree_lookup(tree); + Properties properties = Properties(num_props); + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, *image, chan, &references); + InitPropsRow(&properties, static_props, y); + if (y > 1 && channel.w > 8 && references.w == 0) { + for (size_t x = 0; x < 2; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + for (size_t x = 2; x < channel.w - 2; x++) { + PredictionResult res = + PredictTreeNoWPNEC(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + for (size_t x = channel.w - 2; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + } else { + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + } + } + } else { + JXL_DEBUG_V(8, "Slowest track."); + MATreeLookup tree_lookup(tree); + Properties properties = Properties(num_props); + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + InitPropsRow(&properties, static_props, y); + PrecomputeReferences(channel, y, *image, chan, &references); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + } + return true; +} + +GroupHeader::GroupHeader() { Bundle::Init(this); } + +Status ValidateChannelDimensions(const Image &image, + const ModularOptions &options) { + size_t nb_channels = image.channel.size(); + for (bool is_dc : {true, false}) { + size_t group_dim = options.group_dim * (is_dc ? kBlockDim : 1); + size_t c = image.nb_meta_channels; + for (; c < nb_channels; c++) { + const Channel &ch = image.channel[c]; + if (ch.w > options.group_dim || ch.h > options.group_dim) break; + } + for (; c < nb_channels; c++) { + const Channel &ch = image.channel[c]; + if (ch.w == 0 || ch.h == 0) continue; // skip empty + bool is_dc_channel = std::min(ch.hshift, ch.vshift) >= 3; + if (is_dc_channel != is_dc) continue; + size_t tile_dim = group_dim >> std::max(ch.hshift, ch.vshift); + if (tile_dim == 0) { + return JXL_FAILURE("Inconsistent transforms"); + } + } + } + return true; +} + +Status ModularDecode(BitReader *br, Image &image, GroupHeader &header, + size_t group_id, ModularOptions *options, + const Tree *global_tree, const ANSCode *global_code, + const std::vector<uint8_t> *global_ctx_map, + bool allow_truncated_group) { + if (image.channel.empty()) return true; + + // decode transforms + Status status = Bundle::Read(br, &header); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(status); + if (status.IsFatalError()) return status; + if (!br->AllReadsWithinBounds()) { + // Don't do/undo transforms if header is incomplete. + header.transforms.clear(); + image.transform = header.transforms; + for (size_t c = 0; c < image.channel.size(); c++) { + ZeroFillImage(&image.channel[c].plane); + } + return Status(StatusCode::kNotEnoughBytes); + } + + JXL_DEBUG_V(3, "Image data underwent %" PRIuS " transformations: ", + header.transforms.size()); + image.transform = header.transforms; + for (Transform &transform : image.transform) { + JXL_RETURN_IF_ERROR(transform.MetaApply(image)); + } + if (image.error) { + return JXL_FAILURE("Corrupt file. Aborting."); + } + JXL_RETURN_IF_ERROR(ValidateChannelDimensions(image, *options)); + + size_t nb_channels = image.channel.size(); + + size_t num_chans = 0; + size_t distance_multiplier = 0; + for (size_t i = 0; i < nb_channels; i++) { + Channel &channel = image.channel[i]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + if (channel.w > distance_multiplier) { + distance_multiplier = channel.w; + } + num_chans++; + } + if (num_chans == 0) return true; + + size_t next_channel = 0; + auto scope_guard = MakeScopeGuard([&]() { + // Do not do anything if truncated groups are not allowed. + if (!allow_truncated_group) return; + for (size_t c = next_channel; c < nb_channels; c++) { + ZeroFillImage(&image.channel[c].plane); + } + }); + + // Read tree. + Tree tree_storage; + std::vector<uint8_t> context_map_storage; + ANSCode code_storage; + const Tree *tree = &tree_storage; + const ANSCode *code = &code_storage; + const std::vector<uint8_t> *context_map = &context_map_storage; + if (!header.use_global_tree) { + size_t max_tree_size = 1024; + for (size_t i = 0; i < nb_channels; i++) { + Channel &channel = image.channel[i]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + size_t pixels = channel.w * channel.h; + if (pixels / channel.w != channel.h) { + return JXL_FAILURE("Tree size overflow"); + } + max_tree_size += pixels; + if (max_tree_size < pixels) return JXL_FAILURE("Tree size overflow"); + } + max_tree_size = std::min(static_cast<size_t>(1 << 20), max_tree_size); + JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, max_tree_size)); + JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2, + &code_storage, &context_map_storage)); + } else { + if (!global_tree || !global_code || !global_ctx_map || + global_tree->empty()) { + return JXL_FAILURE("No global tree available but one was requested"); + } + tree = global_tree; + code = global_code; + context_map = global_ctx_map; + } + + // Read channels + ANSSymbolReader reader(code, br, distance_multiplier); + for (; next_channel < nb_channels; next_channel++) { + Channel &channel = image.channel[next_channel]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (next_channel >= image.nb_meta_channels && + (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS( + br, &reader, *context_map, *tree, header.wp_header, next_channel, + group_id, &image)); + // Truncated group. + if (!br->AllReadsWithinBounds()) { + if (!allow_truncated_group) return JXL_FAILURE("Truncated input"); + return Status(StatusCode::kNotEnoughBytes); + } + } + + // Make sure no zero-filling happens even if next_channel < nb_channels. + scope_guard.Disarm(); + + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("ANS decode final state failed"); + } + return true; +} + +Status ModularGenericDecompress(BitReader *br, Image &image, + GroupHeader *header, size_t group_id, + ModularOptions *options, bool undo_transforms, + const Tree *tree, const ANSCode *code, + const std::vector<uint8_t> *ctx_map, + bool allow_truncated_group) { +#ifdef JXL_ENABLE_ASSERT + std::vector<std::pair<uint32_t, uint32_t>> req_sizes(image.channel.size()); + for (size_t c = 0; c < req_sizes.size(); c++) { + req_sizes[c] = {image.channel[c].w, image.channel[c].h}; + } +#endif + GroupHeader local_header; + if (header == nullptr) header = &local_header; + size_t bit_pos = br->TotalBitsConsumed(); + auto dec_status = ModularDecode(br, image, *header, group_id, options, tree, + code, ctx_map, allow_truncated_group); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status); + if (dec_status.IsFatalError()) return dec_status; + if (undo_transforms) image.undo_transforms(header->wp_header); + if (image.error) return JXL_FAILURE("Corrupt file. Aborting."); + JXL_DEBUG_V(4, + "Modular-decoded a %" PRIuS "x%" PRIuS " nbchans=%" PRIuS + " image from %" PRIuS " bytes", + image.w, image.h, image.channel.size(), + (br->TotalBitsConsumed() - bit_pos) / 8); + JXL_DEBUG_V(5, "Modular image: %s", image.DebugString().c_str()); + (void)bit_pos; +#ifdef JXL_ENABLE_ASSERT + // Check that after applying all transforms we are back to the requested image + // sizes, otherwise there's a programming error with the transformations. + if (undo_transforms) { + JXL_ASSERT(image.channel.size() == req_sizes.size()); + for (size_t c = 0; c < req_sizes.size(); c++) { + JXL_ASSERT(req_sizes[c].first == image.channel[c].w); + JXL_ASSERT(req_sizes[c].second == image.channel[c].h); + } + } +#endif + return dec_status; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h new file mode 100644 index 0000000000..89697bce87 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h @@ -0,0 +1,135 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENCODING_H_ +#define LIB_JXL_MODULAR_ENCODING_ENCODING_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +// Valid range of properties for using lookup tables instead of trees. +constexpr int32_t kPropRangeFast = 512; + +struct GroupHeader : public Fields { + GroupHeader(); + + JXL_FIELDS_NAME(GroupHeader) + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &use_global_tree)); + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&wp_header)); + uint32_t num_transforms = static_cast<uint32_t>(transforms.size()); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(4, 2), + BitsOffset(8, 18), 0, + &num_transforms)); + if (visitor->IsReading()) transforms.resize(num_transforms); + for (size_t i = 0; i < num_transforms; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&transforms[i])); + } + return true; + } + + bool use_global_tree; + weighted::Header wp_header; + + std::vector<Transform> transforms; +}; + +FlatTree FilterTree(const Tree &global_tree, + std::array<pixel_type, kNumStaticProperties> &static_props, + size_t *num_props, bool *use_wp, bool *wp_only, + bool *gradient_only); + +template <typename T> +bool TreeToLookupTable(const FlatTree &tree, + T context_lookup[2 * kPropRangeFast], + int8_t offsets[2 * kPropRangeFast], + int8_t multipliers[2 * kPropRangeFast] = nullptr) { + struct TreeRange { + // Begin *excluded*, end *included*. This works best with > vs <= decision + // nodes. + int begin, end; + size_t pos; + }; + std::vector<TreeRange> ranges; + ranges.push_back(TreeRange{-kPropRangeFast - 1, kPropRangeFast - 1, 0}); + while (!ranges.empty()) { + TreeRange cur = ranges.back(); + ranges.pop_back(); + if (cur.begin < -kPropRangeFast - 1 || cur.begin >= kPropRangeFast - 1 || + cur.end > kPropRangeFast - 1) { + // Tree is outside the allowed range, exit. + return false; + } + auto &node = tree[cur.pos]; + // Leaf. + if (node.property0 == -1) { + if (node.predictor_offset < std::numeric_limits<int8_t>::min() || + node.predictor_offset > std::numeric_limits<int8_t>::max()) { + return false; + } + if (node.multiplier < std::numeric_limits<int8_t>::min() || + node.multiplier > std::numeric_limits<int8_t>::max()) { + return false; + } + if (multipliers == nullptr && node.multiplier != 1) { + return false; + } + for (int i = cur.begin + 1; i < cur.end + 1; i++) { + context_lookup[i + kPropRangeFast] = node.childID; + if (multipliers) multipliers[i + kPropRangeFast] = node.multiplier; + offsets[i + kPropRangeFast] = node.predictor_offset; + } + continue; + } + // > side of top node. + if (node.properties[0] >= kNumStaticProperties) { + ranges.push_back(TreeRange({node.splitvals[0], cur.end, node.childID})); + ranges.push_back( + TreeRange({node.splitval0, node.splitvals[0], node.childID + 1})); + } else { + ranges.push_back(TreeRange({node.splitval0, cur.end, node.childID})); + } + // <= side + if (node.properties[1] >= kNumStaticProperties) { + ranges.push_back( + TreeRange({node.splitvals[1], node.splitval0, node.childID + 2})); + ranges.push_back( + TreeRange({cur.begin, node.splitvals[1], node.childID + 3})); + } else { + ranges.push_back( + TreeRange({cur.begin, node.splitval0, node.childID + 2})); + } + } + return true; +} +// TODO(veluca): make cleaner interfaces. + +Status ValidateChannelDimensions(const Image &image, + const ModularOptions &options); + +Status ModularGenericDecompress(BitReader *br, Image &image, + GroupHeader *header, size_t group_id, + ModularOptions *options, + bool undo_transforms = true, + const Tree *tree = nullptr, + const ANSCode *code = nullptr, + const std::vector<uint8_t> *ctx_map = nullptr, + bool allow_truncated_group = false); +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENCODING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h new file mode 100644 index 0000000000..71b7847321 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ +#define LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ + +#include <stddef.h> + +namespace jxl { + +enum MATreeContext : size_t { + kSplitValContext = 0, + kPropertyContext = 1, + kPredictorContext = 2, + kOffsetContext = 3, + kMultiplierLogContext = 4, + kMultiplierBitsContext = 5, + + kNumTreeContexts = 6, +}; + +static constexpr size_t kMaxTreeSize = 1 << 22; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ |