diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/jpeg-xl/lib/jxl/modular | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/modular')
31 files changed, 6075 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h new file mode 100644 index 0000000000..914cd6a4e4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h @@ -0,0 +1,626 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ +#define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ + +#include <utility> +#include <vector> + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +namespace weighted { +constexpr static size_t kNumPredictors = 4; +constexpr static int64_t kPredExtraBits = 3; +constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1; +constexpr static size_t kNumProperties = 1; + +struct Header : public Fields { + JXL_FIELDS_NAME(WeightedPredictorHeader) + // TODO(janwas): move to cc file, avoid including fields.h. + Header() { Bundle::Init(this); } + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + auto visit_p = [visitor](pixel_type val, pixel_type *p) { + uint32_t up = *p; + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up)); + *p = up; + return Status(true); + }; + JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3])); + return true; + } + + bool all_default; + pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0; + uint32_t w[kNumPredictors] = {}; +}; + +struct State { + pixel_type_w prediction[kNumPredictors] = {}; + pixel_type_w pred = 0; // *before* removing the added bits. + std::vector<uint32_t> pred_errors[kNumPredictors]; + std::vector<int32_t> error; + const Header header; + + // Allows to approximate division by a number from 1 to 64. + uint32_t divlookup[64]; + + constexpr static pixel_type_w AddBits(pixel_type_w x) { + return uint64_t(x) << kPredExtraBits; + } + + State(Header header, size_t xsize, size_t ysize) : header(header) { + // Extra margin to avoid out-of-bounds writes. + // All have space for two rows of data. + for (size_t i = 0; i < 4; i++) { + pred_errors[i].resize((xsize + 2) * 2); + } + error.resize((xsize + 2) * 2); + // Initialize division lookup table. + for (int i = 0; i < 64; i++) { + divlookup[i] = (1 << 24) / (i + 1); + } + } + + // Approximates 4+(maxweight<<24)/(x+1), avoiding division + JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const { + int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5; + if (shift < 0) shift = 0; + return 4 + ((maxweight * divlookup[x >> shift]) >> shift); + } + + // Approximates the weighted average of the input values with the given + // weights, avoiding division. Weights must sum to at least 16. + JXL_INLINE pixel_type_w + WeightedAverage(const pixel_type_w *JXL_RESTRICT p, + std::array<uint32_t, kNumPredictors> w) const { + uint32_t weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + weight_sum += w[i]; + } + JXL_DASSERT(weight_sum > 15); + uint32_t log_weight = FloorLog2Nonzero(weight_sum); // at least 4. + weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + w[i] >>= log_weight - 4; + weight_sum += w[i]; + } + // for rounding. + pixel_type_w sum = (weight_sum >> 1) - 1; + for (size_t i = 0; i < kNumPredictors; i++) { + sum += p[i] * w[i]; + } + return (sum * divlookup[weight_sum - 1]) >> 24; + } + + template <bool compute_properties> + JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize, + pixel_type_w N, pixel_type_w W, + pixel_type_w NE, pixel_type_w NW, + pixel_type_w NN, Properties *properties, + size_t offset) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + size_t pos_N = prev_row + x; + size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N; + size_t pos_NW = x > 0 ? pos_N - 1 : pos_N; + std::array<uint32_t, kNumPredictors> weights; + for (size_t i = 0; i < kNumPredictors; i++) { + // pred_errors[pos_N] also contains the error of pixel W. + // pred_errors[pos_NW] also contains the error of pixel WW. + weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] + + pred_errors[i][pos_NW]; + weights[i] = ErrorWeight(weights[i], header.w[i]); + } + + N = AddBits(N); + W = AddBits(W); + NE = AddBits(NE); + NW = AddBits(NW); + NN = AddBits(NN); + + pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1]; + pixel_type_w teN = error[pos_N]; + pixel_type_w teNW = error[pos_NW]; + pixel_type_w sumWN = teN + teW; + pixel_type_w teNE = error[pos_NE]; + + if (compute_properties) { + pixel_type_w p = teW; + if (std::abs(teN) > std::abs(p)) p = teN; + if (std::abs(teNW) > std::abs(p)) p = teNW; + if (std::abs(teNE) > std::abs(p)) p = teNE; + (*properties)[offset++] = p; + } + + prediction[0] = W + NE - N; + prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5); + prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5); + prediction[3] = + N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc + + (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >> + 5); + + pred = WeightedAverage(prediction, weights); + + // If all three have the same sign, skip clamping. + if (((teN ^ teW) | (teN ^ teNW)) > 0) { + return (pred + kPredictionRound) >> kPredExtraBits; + } + + // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N). + pixel_type_w mx = std::max(W, std::max(NE, N)); + pixel_type_w mn = std::min(W, std::min(NE, N)); + pred = std::max(mn, std::min(mx, pred)); + return (pred + kPredictionRound) >> kPredExtraBits; + } + + JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y, + size_t xsize) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + val = AddBits(val); + error[cur_row + x] = pred - val; + for (size_t i = 0; i < kNumPredictors; i++) { + pixel_type_w err = + (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits; + // For predicting in the next row. + pred_errors[i][cur_row + x] = err; + // Add the error on this pixel to the error on the NE pixel. This has the + // effect of adding the error on this pixel to the E and EE pixels. + pred_errors[i][prev_row + x + 1] += err; + } + } +}; + +// Encoder helper function to set the parameters to some presets. +inline void PredictorMode(int i, Header *header) { + switch (i) { + case 0: + // ~ lossless16 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 10; + header->p3Ca = 7; + header->p3Cb = 7; + header->p3Cc = 7; + header->p3Cd = 0; + header->p3Ce = 0; + break; + case 1: + // ~ default lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xb; + header->p1C = 8; + header->p2C = 8; + header->p3Ca = 4; + header->p3Cb = 0; + header->p3Cc = 3; + header->p3Cd = 23; + header->p3Ce = 2; + break; + case 2: + // ~ west lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xd; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 9; + header->p3Ca = 7; + header->p3Cb = 0; + header->p3Cc = 0; + header->p3Cd = 16; + header->p3Ce = 9; + break; + case 3: + // ~ north lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xd; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 8; + header->p3Ca = 0; + header->p3Cb = 16; + header->p3Cc = 0; + header->p3Cd = 23; + header->p3Ce = 0; + break; + case 4: + default: + // something else, because why not + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 10; + header->p3Ca = 5; + header->p3Cb = 5; + header->p3Cc = 5; + header->p3Cd = 12; + header->p3Ce = 4; + break; + } +} +} // namespace weighted + +// Stores a node and its two children at the same time. This significantly +// reduces the number of branches needed during decoding. +struct FlatDecisionNode { + // Property + splitval of the top node. + int32_t property0; // -1 if leaf. + union { + PropertyVal splitval0; + Predictor predictor; + }; + uint32_t childID; // childID is ctx id if leaf. + // Property+splitval of the two child nodes. + union { + PropertyVal splitvals[2]; + int32_t multiplier; + }; + union { + int32_t properties[2]; + int64_t predictor_offset; + }; +}; +using FlatTree = std::vector<FlatDecisionNode>; + +class MATreeLookup { + public: + explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {} + struct LookupResult { + uint32_t context; + Predictor predictor; + int64_t offset; + int32_t multiplier; + }; + JXL_INLINE LookupResult Lookup(const Properties &properties) const { + uint32_t pos = 0; + while (true) { + const FlatDecisionNode &node = nodes_[pos]; + if (node.property0 < 0) { + return {node.childID, node.predictor, node.predictor_offset, + node.multiplier}; + } + bool p0 = properties[node.property0] <= node.splitval0; + uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0]; + uint32_t off1 = + 2 | (properties[node.properties[1]] <= node.splitvals[1] ? 1 : 0); + pos = node.childID + (p0 ? off1 : off0); + } + } + + private: + const FlatTree &nodes_; +}; + +static constexpr size_t kExtraPropsPerChannel = 4; +static constexpr size_t kNumNonrefProperties = + kNumStaticProperties + 13 + weighted::kNumProperties; + +constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties; +constexpr size_t kGradientProp = 9; + +// Clamps gradient to the min/max of n, w (and l, implicitly). +static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w, + const int32_t l) { + const int32_t m = std::min(n, w); + const int32_t M = std::max(n, w); + // The end result of this operation doesn't overflow or underflow if the + // result is between m and M, but the intermediate value may overflow, so we + // do the intermediate operations in uint32_t and check later if we had an + // overflow or underflow condition comparing m, M and l directly. + // grad = M + m - l = n + w - l + const int32_t grad = + static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) - + static_cast<uint32_t>(l)); + // We use two sets of ternary operators to force the evaluation of them in + // any case, allowing the compiler to avoid branches and use cmovl/cmovg in + // x86. + const int32_t grad_clamp_M = (l < m) ? M : grad; + return (l > M) ? m : grad_clamp_M; +} + +inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) { + pixel_type_w p = a + b - c; + pixel_type_w pa = std::abs(p - a); + pixel_type_w pb = std::abs(p - b); + return pa < pb ? a : b; +} + +inline void PrecomputeReferences(const Channel &ch, size_t y, + const Image &image, uint32_t i, + Channel *references) { + ZeroFillImage(&references->plane); + uint32_t offset = 0; + size_t num_extra_props = references->w; + intptr_t onerow = references->plane.PixelsPerRow(); + for (int32_t j = static_cast<int32_t>(i) - 1; + j >= 0 && offset < num_extra_props; j--) { + if (image.channel[j].w != image.channel[i].w || + image.channel[j].h != image.channel[i].h) { + continue; + } + if (image.channel[j].hshift != image.channel[i].hshift) continue; + if (image.channel[j].vshift != image.channel[i].vshift) continue; + pixel_type *JXL_RESTRICT rp = references->Row(0) + offset; + const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y); + const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0); + for (size_t x = 0; x < ch.w; x++, rp += onerow) { + pixel_type_w v = rpp[x]; + rp[0] = std::abs(v); + rp[1] = v; + pixel_type_w vleft = (x ? rpp[x - 1] : 0); + pixel_type_w vtop = (y ? rpprev[x] : vleft); + pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft); + pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft); + rp[2] = std::abs(v - vpredicted); + rp[3] = v - vpredicted; + } + + offset += kExtraPropsPerChannel; + } +} + +struct PredictionResult { + int context = 0; + pixel_type_w guess = 0; + Predictor predictor; + int32_t multiplier; +}; + +inline void InitPropsRow( + Properties *p, + const std::array<pixel_type, kNumStaticProperties> &static_props, + const int y) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + (*p)[i] = static_props[i]; + } + (*p)[2] = y; + (*p)[9] = 0; // local gradient. +} + +namespace detail { +enum PredictorMode { + kUseTree = 1, + kUseWP = 2, + kForceComputeProperties = 4, + kAllPredictions = 8, + kNoEdgeCases = 16 +}; + +JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left, + pixel_type_w top, pixel_type_w toptop, + pixel_type_w topleft, pixel_type_w topright, + pixel_type_w leftleft, + pixel_type_w toprightright, + pixel_type_w wp_pred) { + switch (p) { + case Predictor::Zero: + return pixel_type_w{0}; + case Predictor::Left: + return left; + case Predictor::Top: + return top; + case Predictor::Select: + return Select(left, top, topleft); + case Predictor::Weighted: + return wp_pred; + case Predictor::Gradient: + return pixel_type_w{ClampedGradient(left, top, topleft)}; + case Predictor::TopLeft: + return topleft; + case Predictor::TopRight: + return topright; + case Predictor::LeftLeft: + return leftleft; + case Predictor::Average0: + return (left + top) / 2; + case Predictor::Average1: + return (left + topleft) / 2; + case Predictor::Average2: + return (topleft + top) / 2; + case Predictor::Average3: + return (top + topright) / 2; + case Predictor::Average4: + return (6 * top - 2 * toptop + 7 * left + 1 * leftleft + + 1 * toprightright + 3 * topright + 8) / + 16; + default: + return pixel_type_w{0}; + } +} + +template <int mode> +JXL_INLINE PredictionResult Predict( + Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const size_t x, const size_t y, Predictor predictor, + const MATreeLookup *lookup, const Channel *references, + weighted::State *wp_state, pixel_type_w *predictions) { + // We start in position 3 because of 2 static properties + y. + size_t offset = 3; + constexpr bool compute_properties = + mode & kUseTree || mode & kForceComputeProperties; + constexpr bool nec = mode & kNoEdgeCases; + pixel_type_w left = (nec || x ? pp[-1] : (y ? pp[-onerow] : 0)); + pixel_type_w top = (nec || y ? pp[-onerow] : left); + pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left); + pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top); + pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left); + pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top); + pixel_type_w toprightright = + (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright); + + if (compute_properties) { + // location + (*p)[offset++] = x; + // neighbors + (*p)[offset++] = std::abs(top); + (*p)[offset++] = std::abs(left); + (*p)[offset++] = top; + (*p)[offset++] = left; + + // local gradient + (*p)[offset] = left - (*p)[offset + 1]; + offset++; + // local gradient + (*p)[offset++] = left + top - topleft; + + // FFV1 context properties + (*p)[offset++] = left - topleft; + (*p)[offset++] = topleft - top; + (*p)[offset++] = top - topright; + (*p)[offset++] = top - toptop; + (*p)[offset++] = left - leftleft; + } + + pixel_type_w wp_pred = 0; + if (mode & kUseWP) { + wp_pred = wp_state->Predict<compute_properties>( + x, y, w, top, left, topright, topleft, toptop, p, offset); + } + if (!nec && compute_properties) { + offset += weighted::kNumProperties; + // Extra properties. + const pixel_type *JXL_RESTRICT rp = references->Row(x); + for (size_t i = 0; i < references->w; i++) { + (*p)[offset++] = rp[i]; + } + } + PredictionResult result; + if (mode & kUseTree) { + MATreeLookup::LookupResult lr = lookup->Lookup(*p); + result.context = lr.context; + result.guess = lr.offset; + result.multiplier = lr.multiplier; + predictor = lr.predictor; + } + if (mode & kAllPredictions) { + for (size_t i = 0; i < kNumModularPredictors; i++) { + predictions[i] = PredictOne((Predictor)i, left, top, toptop, topleft, + topright, leftleft, toprightright, wp_pred); + } + } + result.guess += PredictOne(predictor, left, top, toptop, topleft, topright, + leftleft, toprightright, wp_pred); + result.predictor = predictor; + + return result; +} +} // namespace detail + +inline PredictionResult PredictNoTreeNoWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor) { + return detail::Predict</*mode=*/0>( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictNoTreeWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + weighted::State *wp_state) { + return detail::Predict<detail::kUseWP>( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeNoWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references) { + return detail::Predict<detail::kUseTree>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + /*wp_state=*/nullptr, /*predictions=*/nullptr); +} +// Only use for y > 1, x > 1, x < w-2, and empty references +JXL_INLINE PredictionResult +PredictTreeNoWPNEC(Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const MATreeLookup &tree_lookup, const Channel &references) { + return detail::Predict<detail::kUseTree | detail::kNoEdgeCases>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict<detail::kUseTree | detail::kUseWP>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictLearn(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>( + p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references, + wp_state, /*predictions=*/nullptr); +} + +inline void PredictLearnAll(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const Channel &references, + weighted::State *wp_state, + pixel_type_w *predictions) { + detail::Predict<detail::kForceComputeProperties | detail::kUseWP | + detail::kAllPredictions>( + p, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, &references, wp_state, predictions); +} + +inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + pixel_type_w *predictions) { + detail::Predict<detail::kAllPredictions>( + /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, predictions); +} +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc new file mode 100644 index 0000000000..66562f7dfd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc @@ -0,0 +1,107 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/dec_ma.h" + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +namespace { + +Status ValidateTree( + const Tree &tree, + const std::vector<std::pair<pixel_type, pixel_type>> &prop_bounds, + size_t root) { + if (tree[root].property == -1) return true; + size_t p = tree[root].property; + int val = tree[root].splitval; + if (prop_bounds[p].first > val) return JXL_FAILURE("Invalid tree"); + // Splitting at max value makes no sense: left range will be exactly same + // as parent, right range will be invalid (min > max). + if (prop_bounds[p].second <= val) return JXL_FAILURE("Invalid tree"); + auto new_bounds = prop_bounds; + new_bounds[p].first = val + 1; + JXL_RETURN_IF_ERROR(ValidateTree(tree, new_bounds, tree[root].lchild)); + new_bounds[p] = prop_bounds[p]; + new_bounds[p].second = val; + return ValidateTree(tree, new_bounds, tree[root].rchild); +} + +Status DecodeTree(BitReader *br, ANSSymbolReader *reader, + const std::vector<uint8_t> &context_map, Tree *tree, + size_t tree_size_limit) { + size_t leaf_id = 0; + size_t to_decode = 1; + tree->clear(); + while (to_decode > 0) { + JXL_RETURN_IF_ERROR(br->AllReadsWithinBounds()); + if (tree->size() > tree_size_limit) { + return JXL_FAILURE("Tree is too large: %" PRIuS " nodes vs %" PRIuS + " max nodes", + tree->size(), tree_size_limit); + } + to_decode--; + uint32_t prop1 = reader->ReadHybridUint(kPropertyContext, br, context_map); + if (prop1 > 256) return JXL_FAILURE("Invalid tree property value"); + int property = prop1 - 1; + if (property == -1) { + size_t predictor = + reader->ReadHybridUint(kPredictorContext, br, context_map); + if (predictor >= kNumModularPredictors) { + return JXL_FAILURE("Invalid predictor"); + } + int64_t predictor_offset = + UnpackSigned(reader->ReadHybridUint(kOffsetContext, br, context_map)); + uint32_t mul_log = + reader->ReadHybridUint(kMultiplierLogContext, br, context_map); + if (mul_log >= 31) { + return JXL_FAILURE("Invalid multiplier logarithm"); + } + uint32_t mul_bits = + reader->ReadHybridUint(kMultiplierBitsContext, br, context_map); + if (mul_bits + 1 >= 1u << (31u - mul_log)) { + return JXL_FAILURE("Invalid multiplier"); + } + uint32_t multiplier = (mul_bits + 1U) << mul_log; + tree->emplace_back(-1, 0, leaf_id++, 0, static_cast<Predictor>(predictor), + predictor_offset, multiplier); + continue; + } + int splitval = + UnpackSigned(reader->ReadHybridUint(kSplitValContext, br, context_map)); + tree->emplace_back(property, splitval, tree->size() + to_decode + 1, + tree->size() + to_decode + 2, Predictor::Zero, 0, 1); + to_decode += 2; + } + std::vector<std::pair<pixel_type, pixel_type>> prop_bounds; + prop_bounds.resize(256, {std::numeric_limits<pixel_type>::min(), + std::numeric_limits<pixel_type>::max()}); + return ValidateTree(*tree, prop_bounds, 0); +} +} // namespace + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit) { + std::vector<uint8_t> tree_context_map; + ANSCode tree_code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kNumTreeContexts, &tree_code, &tree_context_map)); + // TODO(eustas): investigate more infinite tree cases. + if (tree_code.degenerate_symbols[tree_context_map[kPropertyContext]] > 0) { + return JXL_FAILURE("Infinite tree"); + } + ANSSymbolReader reader(&tree_code, br); + JXL_RETURN_IF_ERROR(DecodeTree(br, &reader, tree_context_map, tree, + std::min(tree_size_limit, kMaxTreeSize))); + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("ANS decode final state failed"); + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h new file mode 100644 index 0000000000..a910c4deb1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ +#define LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// inner nodes +struct PropertyDecisionNode { + PropertyVal splitval; + int16_t property; // -1: leaf node, lchild points to leaf node + uint32_t lchild; + uint32_t rchild; + Predictor predictor; + int64_t predictor_offset; + uint32_t multiplier; + + PropertyDecisionNode(int p, int split_val, int lchild, int rchild, + Predictor predictor, int64_t predictor_offset, + uint32_t multiplier) + : splitval(split_val), + property(p), + lchild(lchild), + rchild(rchild), + predictor(predictor), + predictor_offset(predictor_offset), + multiplier(multiplier) {} + PropertyDecisionNode() + : splitval(0), + property(-1), + lchild(0), + rchild(0), + predictor(Predictor::Zero), + predictor_offset(0), + multiplier(1) {} + static PropertyDecisionNode Leaf(Predictor predictor, int64_t offset = 0, + uint32_t multiplier = 1) { + return PropertyDecisionNode(-1, 0, 0, 0, predictor, offset, multiplier); + } + static PropertyDecisionNode Split(int p, int split_val, int lchild, + int rchild = -1) { + if (rchild == -1) rchild = lchild + 1; + return PropertyDecisionNode(p, split_val, lchild, rchild, Predictor::Zero, + 0, 1); + } +}; + +using Tree = std::vector<PropertyDecisionNode>; + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc new file mode 100644 index 0000000000..f2a1705e4b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc @@ -0,0 +1,124 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/enc_debug_tree.h" + +#include <stdint.h> +#include <stdlib.h> + +#include "lib/jxl/base/os_macros.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/options.h" + +#if JXL_OS_IOS +#define JXL_ENABLE_DOT 0 +#else +#define JXL_ENABLE_DOT 1 // iOS lacks C89 system() +#endif + +namespace jxl { + +const char *PredictorName(Predictor p) { + switch (p) { + case Predictor::Zero: + return "Zero"; + case Predictor::Left: + return "Left"; + case Predictor::Top: + return "Top"; + case Predictor::Average0: + return "Avg0"; + case Predictor::Average1: + return "Avg1"; + case Predictor::Average2: + return "Avg2"; + case Predictor::Average3: + return "Avg3"; + case Predictor::Average4: + return "Avg4"; + case Predictor::Select: + return "Sel"; + case Predictor::Gradient: + return "Grd"; + case Predictor::Weighted: + return "Wgh"; + case Predictor::TopLeft: + return "TopL"; + case Predictor::TopRight: + return "TopR"; + case Predictor::LeftLeft: + return "LL"; + default: + return "INVALID"; + }; +} + +std::string PropertyName(size_t i) { + static_assert(kNumNonrefProperties == 16, "Update this function"); + switch (i) { + case 0: + return "c"; + case 1: + return "g"; + case 2: + return "y"; + case 3: + return "x"; + case 4: + return "|N|"; + case 5: + return "|W|"; + case 6: + return "N"; + case 7: + return "W"; + case 8: + return "W-WW-NW+NWW"; + case 9: + return "W+N-NW"; + case 10: + return "W-NW"; + case 11: + return "NW-N"; + case 12: + return "N-NE"; + case 13: + return "N-NN"; + case 14: + return "W-WW"; + case 15: + return "WGH"; + default: + return "ch[" + ToString(15 - (int)i) + "]"; + } +} + +void PrintTree(const Tree &tree, const std::string &path) { + FILE *f = fopen((path + ".dot").c_str(), "w"); + fprintf(f, "graph{\n"); + for (size_t cur = 0; cur < tree.size(); cur++) { + if (tree[cur].property < 0) { + fprintf(f, "n%05" PRIuS " [label=\"%s%+" PRId64 " (x%u)\"];\n", cur, + PredictorName(tree[cur].predictor), tree[cur].predictor_offset, + tree[cur].multiplier); + } else { + fprintf(f, "n%05" PRIuS " [label=\"%s>%d\"];\n", cur, + PropertyName(tree[cur].property).c_str(), tree[cur].splitval); + fprintf(f, "n%05" PRIuS " -- n%05d;\n", cur, tree[cur].lchild); + fprintf(f, "n%05" PRIuS " -- n%05d;\n", cur, tree[cur].rchild); + } + } + fprintf(f, "}\n"); + fclose(f); +#if JXL_ENABLE_DOT + JXL_ASSERT( + system(("dot " + path + ".dot -T svg -o " + path + ".svg").c_str()) == 0); +#endif +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h new file mode 100644 index 0000000000..78deaab1b8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h @@ -0,0 +1,27 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +const char *PredictorName(Predictor p); +std::string PropertyName(size_t i); + +void PrintTree(const Tree &tree, const std::string &path); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc new file mode 100644 index 0000000000..c8c183335e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc @@ -0,0 +1,562 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <stdint.h> +#include <stdlib.h> + +#include <cinttypes> +#include <limits> +#include <numeric> +#include <queue> +#include <set> +#include <unordered_map> +#include <unordered_set> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_debug_tree.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +// Plot tree (if enabled) and predictor usage map. +constexpr bool kWantDebug = false; +constexpr bool kPrintTree = false; + +inline std::array<uint8_t, 3> PredictorColor(Predictor p) { + switch (p) { + case Predictor::Zero: + return {{0, 0, 0}}; + case Predictor::Left: + return {{255, 0, 0}}; + case Predictor::Top: + return {{0, 255, 0}}; + case Predictor::Average0: + return {{0, 0, 255}}; + case Predictor::Average4: + return {{192, 128, 128}}; + case Predictor::Select: + return {{255, 255, 0}}; + case Predictor::Gradient: + return {{255, 0, 255}}; + case Predictor::Weighted: + return {{0, 255, 255}}; + // TODO + default: + return {{255, 255, 255}}; + }; +} + +} // namespace + +void GatherTreeData(const Image &image, pixel_type chan, size_t group_id, + const weighted::Header &wp_header, + const ModularOptions &options, TreeSamples &tree_samples, + size_t *total_pixels) { + const Channel &channel = image.channel[chan]; + + JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w, + channel.h, chan); + + std::array<pixel_type, kNumStaticProperties> static_props = { + {chan, (int)group_id}}; + Properties properties(kNumNonrefProperties + + kExtraPropsPerChannel * options.max_properties); + double pixel_fraction = std::min(1.0f, options.nb_repeats); + // a fraction of 0 is used to disable learning entirely. + if (pixel_fraction > 0) { + pixel_fraction = std::max(pixel_fraction, + std::min(1.0, 1024.0 / (channel.w * channel.h))); + } + uint64_t threshold = + (std::numeric_limits<uint64_t>::max() >> 32) * pixel_fraction; + uint64_t s[2] = {static_cast<uint64_t>(0x94D049BB133111EBull), + static_cast<uint64_t>(0xBF58476D1CE4E5B9ull)}; + // Xorshift128+ adapted from xorshift128+-inl.h + auto use_sample = [&]() { + auto s1 = s[0]; + const auto s0 = s[1]; + const auto bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return (bits >> 32) <= threshold; + }; + + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + InitPropsRow(&properties, static_props, y); + // TODO(veluca): avoid computing WP if we don't use its property or + // predictions. + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w pred[kNumModularPredictors]; + if (tree_samples.NumPredictors() != 1) { + PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references, + &wp_state, pred); + } else { + pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] = + PredictLearn(&properties, channel.w, p + x, onerow, x, y, + tree_samples.PredictorFromIndex(0), references, + &wp_state) + .guess; + } + (*total_pixels)++; + if (use_sample()) { + tree_samples.AddSample(p[x], properties, pred); + } + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } +} + +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector<ModularMultiplierInfo> &multiplier_info = {}, + StaticPropRange static_prop_range = {}) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + if (static_prop_range[i][1] == 0) { + static_prop_range[i][1] = std::numeric_limits<uint32_t>::max(); + } + } + if (!tree_samples.HasSamples()) { + Tree tree; + tree.emplace_back(); + tree.back().predictor = tree_samples.PredictorFromIndex(0); + tree.back().property = -1; + tree.back().predictor_offset = 0; + tree.back().multiplier = 1; + return tree; + } + float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels; + float required_cost = pixel_fraction * 0.9 + 0.1; + tree_samples.AllSamplesDone(); + Tree tree; + ComputeBestTree(tree_samples, + options.splitting_heuristics_node_threshold * required_cost, + multiplier_info, static_prop_range, + options.fast_decode_multiplier, &tree); + return tree; +} + +Status EncodeModularChannelMAANS(const Image &image, pixel_type chan, + const weighted::Header &wp_header, + const Tree &global_tree, Token **tokenpp, + AuxOut *aux_out, size_t group_id, + bool skip_encoder_fast_path) { + const Channel &channel = image.channel[chan]; + Token *tokenp = *tokenpp; + JXL_ASSERT(channel.w != 0 && channel.h != 0); + + Image3F predictor_img; + if (kWantDebug) predictor_img = Image3F(channel.w, channel.h); + + JXL_DEBUG_V(6, + "Encoding %" PRIuS "x%" PRIuS + " channel %d, " + "(shift=%i,%i)", + channel.w, channel.h, chan, channel.hshift, channel.vshift); + + std::array<pixel_type, kNumStaticProperties> static_props = { + {chan, (int)group_id}}; + bool use_wp, is_wp_only; + bool is_gradient_only; + size_t num_props; + FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp, + &is_wp_only, &is_gradient_only); + Properties properties(num_props); + MATreeLookup tree_lookup(tree); + JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size()); + + // Check if this tree is a WP-only tree with a small enough property value + // range. + // Initialized to avoid clang-tidy complaining. + uint16_t context_lookup[2 * kPropRangeFast] = {}; + int8_t offsets[2 * kPropRangeFast] = {}; + if (is_wp_only) { + is_wp_only = TreeToLookupTable(tree, context_lookup, offsets); + } + if (is_gradient_only) { + is_gradient_only = TreeToLookupTable(tree, context_lookup, offsets); + } + + if (is_wp_only && !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Weighted)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + Properties properties(1); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + int32_t guess = wp_state.Predict</*compute_properties=*/true>( + x, y, channel.w, top, left, topright, topleft, toptop, &properties, + offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + int32_t residual = r[x] - guess - offsets[pos]; + *tokenp++ = Token(ctx_id, PackSigned(residual)); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient && + tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && + !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + int32_t residual = r[x] - guess; + *tokenp++ = Token(tree[0].childID, PackSigned(residual)); + } + } + } else if (is_gradient_only && !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + uint32_t pos = + kPropRangeFast + + std::min<pixel_type_w>( + std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + int32_t residual = r[x] - guess - offsets[pos]; + *tokenp++ = Token(ctx_id, PackSigned(residual)); + } + } + } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero && + tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && + !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Zero)[c]), + &predictor_img.Plane(c)); + } + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + *tokenp++ = Token(tree[0].childID, PackSigned(p[x])); + } + } + } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted && + (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 && + tree[0].predictor_offset == 0 && !skip_encoder_fast_path) { + // multiplier is a power of 2. + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(tree[0].predictor)[c]), + &predictor_img.Plane(c)); + } + uint32_t mul_shift = FloorLog2Nonzero((uint32_t)tree[0].multiplier); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x, + y, tree[0].predictor); + pixel_type_w residual = r[x] - pred.guess; + JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual); + *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift)); + } + } + + } else if (!use_wp && !skip_encoder_fast_path) { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_ASSERT(residual % res.multiplier == 0); + *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); + } + } + } else { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_ASSERT(residual % res.multiplier == 0); + *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + } + if (kWantDebug && WantDebugOutput(aux_out)) { + aux_out->DumpImage( + ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(), + predictor_img); + } + *tokenpp = tokenp; + return true; +} + +Status ModularEncode(const Image &image, const ModularOptions &options, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector<Token> *tokens, + size_t *width) { + if (image.error) return JXL_FAILURE("Invalid image"); + size_t nb_channels = image.channel.size(); + JXL_DEBUG_V( + 2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.", + nb_channels, image.bitdepth, image.w, image.h); + + if (nb_channels < 1) { + return true; // is there any use for a zero-channel image? + } + + // encode transforms + GroupHeader header_storage; + if (header == nullptr) header = &header_storage; + Bundle::Init(header); + if (options.predictor == Predictor::Weighted) { + weighted::PredictorMode(options.wp_mode, &header->wp_header); + } + header->transforms = image.transform; + // This doesn't actually work + if (tree != nullptr) { + header->use_global_tree = true; + } + if (tree_samples == nullptr && tree == nullptr) { + JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out)); + } + + TreeSamples tree_samples_storage; + size_t total_pixels_storage = 0; + if (!total_pixels) total_pixels = &total_pixels_storage; + // If there's no tree, compute one (or gather data to). + if (tree == nullptr) { + bool gather_data = tree_samples != nullptr; + if (tree_samples == nullptr) { + JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor( + options.predictor, options.wp_tree_mode)); + JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties( + options.splitting_heuristics_properties, options.wp_tree_mode)); + std::vector<pixel_type> pixel_samples; + std::vector<pixel_type> diff_samples; + std::vector<uint32_t> group_pixel_count; + std::vector<uint32_t> channel_pixel_count; + CollectPixelSamples(image, options, 0, group_pixel_count, + channel_pixel_count, pixel_samples, diff_samples); + std::vector<ModularMultiplierInfo> dummy_multiplier_info; + StaticPropRange range; + tree_samples_storage.PreQuantizeProperties( + range, dummy_multiplier_info, group_pixel_count, channel_pixel_count, + pixel_samples, diff_samples, options.max_property_values); + } + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + GatherTreeData(image, i, group_id, header->wp_header, options, + gather_data ? *tree_samples : tree_samples_storage, + total_pixels); + } + if (gather_data) return true; + } + + JXL_ASSERT((tree == nullptr) == (tokens == nullptr)); + + Tree tree_storage; + std::vector<std::vector<Token>> tokens_storage(1); + // Compute tree. + if (tree == nullptr) { + EntropyEncodingData code; + std::vector<uint8_t> context_map; + + std::vector<std::vector<Token>> tree_tokens(1); + tree_storage = + LearnTree(std::move(tree_samples_storage), *total_pixels, options); + tree = &tree_storage; + tokens = &tokens_storage[0]; + + Tree decoded_tree; + TokenizeTree(*tree, &tree_tokens[0], &decoded_tree); + JXL_ASSERT(tree->size() == decoded_tree.size()); + tree_storage = std::move(decoded_tree); + + if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) { + PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id)); + } + // Write tree + BuildAndEncodeHistograms(HistogramParams(), kNumTreeContexts, tree_tokens, + &code, &context_map, writer, kLayerModularTree, + aux_out); + WriteTokens(tree_tokens[0], code, context_map, writer, kLayerModularTree, + aux_out); + } + + size_t image_width = 0; + size_t total_tokens = 0; + for (size_t i = 0; i < nb_channels; i++) { + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + if (image.channel[i].w > image_width) image_width = image.channel[i].w; + total_tokens += image.channel[i].w * image.channel[i].h; + } + if (options.zero_tokens) { + tokens->resize(tokens->size() + total_tokens, {0, 0}); + } else { + // Do one big allocation for all the tokens we'll need, + // to avoid reallocs that might require copying. + size_t pos = tokens->size(); + tokens->resize(pos + total_tokens); + Token *tokenp = tokens->data() + pos; + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS( + image, i, header->wp_header, *tree, &tokenp, aux_out, group_id, + options.skip_encoder_fast_path)); + } + // Make sure we actually wrote all tokens + JXL_CHECK(tokenp == tokens->data() + tokens->size()); + } + + // Write data if not using a global tree/ANS stream. + if (!header->use_global_tree) { + EntropyEncodingData code; + std::vector<uint8_t> context_map; + HistogramParams histo_params; + histo_params.image_widths.push_back(image_width); + BuildAndEncodeHistograms(histo_params, (tree->size() + 1) / 2, + tokens_storage, &code, &context_map, writer, layer, + aux_out); + WriteTokens(tokens_storage[0], code, context_map, writer, layer, aux_out); + } else { + *width = image_width; + } + return true; +} + +Status ModularGenericCompress(Image &image, const ModularOptions &opts, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector<Token> *tokens, + size_t *width) { + if (image.w == 0 || image.h == 0) return true; + ModularOptions options = opts; // Make a copy to modify it. + + if (options.predictor == static_cast<Predictor>(-1)) { + options.predictor = Predictor::Gradient; + } + + size_t bits = writer ? writer->BitsWritten() : 0; + JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer, + group_id, tree_samples, total_pixels, tree, + header, tokens, width)); + bits = writer ? writer->BitsWritten() - bits : 0; + if (writer) { + JXL_DEBUG_V(4, + "Modular-encoded a %" PRIuS "x%" PRIuS + " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes", + image.w, image.h, image.bitdepth, image.channel.size(), + bits / 8); + } + (void)bits; + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h new file mode 100644 index 0000000000..04df504750 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h @@ -0,0 +1,47 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector<ModularMultiplierInfo> &multiplier_info = {}, + StaticPropRange static_prop_range = {}); + +// TODO(veluca): make cleaner interfaces. + +Status ModularGenericCompress( + Image &image, const ModularOptions &opts, BitWriter *writer, + AuxOut *aux_out = nullptr, size_t layer = 0, size_t group_id = 0, + // For gathering data for producing a global tree. + TreeSamples *tree_samples = nullptr, size_t *total_pixels = nullptr, + // For encoding with global tree. + const Tree *tree = nullptr, GroupHeader *header = nullptr, + std::vector<Token> *tokens = nullptr, size_t *widths = nullptr); +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc new file mode 100644 index 0000000000..d0f6b47566 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc @@ -0,0 +1,1023 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/enc_ma.h" + +#include <algorithm> +#include <limits> +#include <numeric> +#include <queue> +#include <unordered_map> +#include <unordered_set> + +#include "lib/jxl/modular/encoding/ma_common.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Lt; + +const HWY_FULL(float) df; +const HWY_FULL(int32_t) di; +size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } + +float EstimateBits(const int32_t *counts, int32_t *rounded_counts, + size_t num_symbols) { + // Try to approximate the effect of rounding up nonzero probabilities. + int32_t total = std::accumulate(counts, counts + num_symbols, 0); + const auto min = Set(di, (total + ANS_TAB_SIZE - 1) >> ANS_LOG_TAB_SIZE); + const auto zero_i = Zero(di); + for (size_t i = 0; i < num_symbols; i += Lanes(df)) { + auto counts_v = LoadU(di, &counts[i]); + counts_v = IfThenElse(Eq(counts_v, zero_i), zero_i, + IfThenElse(Lt(counts_v, min), min, counts_v)); + StoreU(counts_v, di, &rounded_counts[i]); + } + // Compute entropy of the "rounded" probabilities. + const auto zero = Zero(df); + const size_t total_scalar = + std::accumulate(rounded_counts, rounded_counts + num_symbols, 0); + const auto inv_total = Set(df, 1.0f / total_scalar); + auto bits_lanes = Zero(df); + auto total_v = Set(di, total_scalar); + for (size_t i = 0; i < num_symbols; i += Lanes(df)) { + const auto counts_v = ConvertTo(df, LoadU(di, &counts[i])); + const auto round_counts_v = LoadU(di, &rounded_counts[i]); + const auto probs = Mul(ConvertTo(df, round_counts_v), inv_total); + const auto nbps = IfThenElse(Eq(round_counts_v, total_v), BitCast(di, zero), + BitCast(di, FastLog2f(df, probs))); + bits_lanes = Sub(bits_lanes, IfThenElse(Eq(counts_v, zero), zero, + Mul(counts_v, BitCast(df, nbps)))); + } + return GetLane(SumOfLanes(df, bits_lanes)); +} + +void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred, + int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { + // Note that the tree splits on *strictly greater*. + (*tree)[pos].lchild = tree->size(); + (*tree)[pos].rchild = tree->size() + 1; + (*tree)[pos].splitval = splitval; + (*tree)[pos].property = property; + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = rpred; + tree->back().predictor_offset = roff; + tree->back().multiplier = 1; + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = lpred; + tree->back().predictor_offset = loff; + tree->back().multiplier = 1; +} + +enum class IntersectionType { kNone, kPartial, kInside }; +IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack, + uint32_t &partial_axis, uint32_t &partial_val) { + bool partial = false; + for (size_t i = 0; i < kNumStaticProperties; i++) { + if (haystack[i][0] >= needle[i][1]) { + return IntersectionType::kNone; + } + if (haystack[i][1] <= needle[i][0]) { + return IntersectionType::kNone; + } + if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { + continue; + } + partial = true; + partial_axis = i; + if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { + partial_val = haystack[i][0] - 1; + } else { + JXL_DASSERT(haystack[i][1] > needle[i][0] && + haystack[i][1] < needle[i][1]); + partial_val = haystack[i][1] - 1; + } + } + return partial ? IntersectionType::kPartial : IntersectionType::kInside; +} + +void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos, + size_t end, size_t prop) { + auto cmp = [&](size_t a, size_t b) { + return int32_t(tree_samples.Property(prop, a)) - + int32_t(tree_samples.Property(prop, b)); + }; + Rng rng(0); + while (end > begin + 1) { + { + size_t pivot = rng.UniformU(begin, end); + tree_samples.Swap(begin, pivot); + } + size_t pivot_begin = begin; + size_t pivot_end = pivot_begin + 1; + for (size_t i = begin + 1; i < end; i++) { + JXL_DASSERT(i >= pivot_end); + JXL_DASSERT(pivot_end > pivot_begin); + int32_t cmp_result = cmp(i, pivot_begin); + if (cmp_result < 0) { // i < pivot, move pivot forward and put i before + // the pivot. + tree_samples.ThreeShuffle(pivot_begin, pivot_end, i); + pivot_begin++; + pivot_end++; + } else if (cmp_result == 0) { + tree_samples.Swap(pivot_end, i); + pivot_end++; + } + } + JXL_DASSERT(pivot_begin >= begin); + JXL_DASSERT(pivot_end > pivot_begin); + JXL_DASSERT(pivot_end <= end); + for (size_t i = begin; i < pivot_begin; i++) { + JXL_DASSERT(cmp(i, pivot_begin) < 0); + } + for (size_t i = pivot_end; i < end; i++) { + JXL_DASSERT(cmp(i, pivot_begin) > 0); + } + for (size_t i = pivot_begin; i < pivot_end; i++) { + JXL_DASSERT(cmp(i, pivot_begin) == 0); + } + // We now have that [begin, pivot_begin) is < pivot, [pivot_begin, + // pivot_end) is = pivot, and [pivot_end, end) is > pivot. + // If pos falls in the first or the last interval, we continue in that + // interval; otherwise, we are done. + if (pivot_begin > pos) { + end = pivot_begin; + } else if (pivot_end < pos) { + begin = pivot_end; + } else { + break; + } + } +} + +void FindBestSplit(TreeSamples &tree_samples, float threshold, + const std::vector<ModularMultiplierInfo> &mul_info, + StaticPropRange initial_static_prop_range, + float fast_decode_multiplier, Tree *tree) { + struct NodeInfo { + size_t pos; + size_t begin; + size_t end; + uint64_t used_properties; + StaticPropRange static_prop_range; + }; + std::vector<NodeInfo> nodes; + nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0, + initial_static_prop_range}); + + size_t num_predictors = tree_samples.NumPredictors(); + size_t num_properties = tree_samples.NumProperties(); + + // TODO(veluca): consider parallelizing the search (processing multiple nodes + // at a time). + while (!nodes.empty()) { + size_t pos = nodes.back().pos; + size_t begin = nodes.back().begin; + size_t end = nodes.back().end; + uint64_t used_properties = nodes.back().used_properties; + StaticPropRange static_prop_range = nodes.back().static_prop_range; + nodes.pop_back(); + if (begin == end) continue; + + struct SplitInfo { + size_t prop = 0; + uint32_t val = 0; + size_t pos = 0; + float lcost = std::numeric_limits<float>::max(); + float rcost = std::numeric_limits<float>::max(); + Predictor lpred = Predictor::Zero; + Predictor rpred = Predictor::Zero; + float Cost() { return lcost + rcost; } + }; + + SplitInfo best_split_static_constant; + SplitInfo best_split_static; + SplitInfo best_split_nonstatic; + SplitInfo best_split_nowp; + + JXL_DASSERT(begin <= end); + JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); + + // Compute the maximum token in the range. + size_t max_symbols = 0; + for (size_t pred = 0; pred < num_predictors; pred++) { + for (size_t i = begin; i < end; i++) { + uint32_t tok = tree_samples.Token(pred, i); + max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1; + } + } + max_symbols = Padded(max_symbols); + std::vector<int32_t> rounded_counts(max_symbols); + std::vector<int32_t> counts(max_symbols * num_predictors); + std::vector<uint32_t> tot_extra_bits(num_predictors); + for (size_t pred = 0; pred < num_predictors; pred++) { + for (size_t i = begin; i < end; i++) { + counts[pred * max_symbols + tree_samples.Token(pred, i)] += + tree_samples.Count(i); + tot_extra_bits[pred] += + tree_samples.NBits(pred, i) * tree_samples.Count(i); + } + } + + float base_bits; + { + size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); + base_bits = EstimateBits(counts.data() + pred * max_symbols, + rounded_counts.data(), max_symbols) + + tot_extra_bits[pred]; + } + + SplitInfo *best = &best_split_nonstatic; + + SplitInfo forced_split; + // The multiplier ranges cut halfway through the current ranges of static + // properties. We do this even if the current node is not a leaf, to + // minimize the number of nodes in the resulting tree. + for (size_t i = 0; i < mul_info.size(); i++) { + uint32_t axis, val; + IntersectionType t = + BoxIntersects(static_prop_range, mul_info[i].range, axis, val); + if (t == IntersectionType::kNone) continue; + if (t == IntersectionType::kInside) { + (*tree)[pos].multiplier = mul_info[i].multiplier; + break; + } + if (t == IntersectionType::kPartial) { + forced_split.val = tree_samples.QuantizeProperty(axis, val); + forced_split.prop = axis; + forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; + forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; + best = &forced_split; + best->pos = begin; + JXL_ASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); + for (size_t x = begin; x < end; x++) { + if (tree_samples.Property(best->prop, x) <= best->val) { + best->pos++; + } + } + break; + } + } + + if (best != &forced_split) { + std::vector<int> prop_value_used_count; + std::vector<int> count_increase; + std::vector<size_t> extra_bits_increase; + // For each property, compute which of its values are used, and what + // tokens correspond to those usages. Then, iterate through the values, + // and compute the entropy of each side of the split (of the form `prop > + // threshold`). Finally, find the split that minimizes the cost. + struct CostInfo { + float cost = std::numeric_limits<float>::max(); + float extra_cost = 0; + float Cost() const { return cost + extra_cost; } + Predictor pred; // will be uninitialized in some cases, but never used. + }; + std::vector<CostInfo> costs_l; + std::vector<CostInfo> costs_r; + + std::vector<int32_t> counts_above(max_symbols); + std::vector<int32_t> counts_below(max_symbols); + + // The lower the threshold, the higher the expected noisiness of the + // estimate. Thus, discourage changing predictors. + float change_pred_penalty = 800.0f / (100.0f + threshold); + for (size_t prop = 0; prop < num_properties && base_bits > threshold; + prop++) { + costs_l.clear(); + costs_r.clear(); + size_t prop_size = tree_samples.NumPropertyValues(prop); + if (extra_bits_increase.size() < prop_size) { + count_increase.resize(prop_size * max_symbols); + extra_bits_increase.resize(prop_size); + } + // Clear prop_value_used_count (which cannot be cleared "on the go") + prop_value_used_count.clear(); + prop_value_used_count.resize(prop_size); + + size_t first_used = prop_size; + size_t last_used = 0; + + // TODO(veluca): consider finding multiple splits along a single + // property at the same time, possibly with a bottom-up approach. + for (size_t i = begin; i < end; i++) { + size_t p = tree_samples.Property(prop, i); + prop_value_used_count[p]++; + last_used = std::max(last_used, p); + first_used = std::min(first_used, p); + } + costs_l.resize(last_used - first_used); + costs_r.resize(last_used - first_used); + // For all predictors, compute the right and left costs of each split. + for (size_t pred = 0; pred < num_predictors; pred++) { + // Compute cost and histogram increments for each property value. + for (size_t i = begin; i < end; i++) { + size_t p = tree_samples.Property(prop, i); + size_t cnt = tree_samples.Count(i); + size_t sym = tree_samples.Token(pred, i); + count_increase[p * max_symbols + sym] += cnt; + extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt; + } + memcpy(counts_above.data(), counts.data() + pred * max_symbols, + max_symbols * sizeof counts_above[0]); + memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); + size_t extra_bits_below = 0; + // Exclude last used: this ensures neither counts_above nor + // counts_below is empty. + for (size_t i = first_used; i < last_used; i++) { + if (!prop_value_used_count[i]) continue; + extra_bits_below += extra_bits_increase[i]; + // The increase for this property value has been used, and will not + // be used again: clear it. Also below. + extra_bits_increase[i] = 0; + for (size_t sym = 0; sym < max_symbols; sym++) { + counts_above[sym] -= count_increase[i * max_symbols + sym]; + counts_below[sym] += count_increase[i * max_symbols + sym]; + count_increase[i * max_symbols + sym] = 0; + } + float rcost = EstimateBits(counts_above.data(), + rounded_counts.data(), max_symbols) + + tot_extra_bits[pred] - extra_bits_below; + float lcost = EstimateBits(counts_below.data(), + rounded_counts.data(), max_symbols) + + extra_bits_below; + JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); + float penalty = 0; + // Never discourage moving away from the Weighted predictor. + if (tree_samples.PredictorFromIndex(pred) != + (*tree)[pos].predictor && + (*tree)[pos].predictor != Predictor::Weighted) { + penalty = change_pred_penalty; + } + // If everything else is equal, disfavour Weighted (slower) and + // favour Zero (faster if it's the only predictor used in a + // group+channel combination) + if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { + penalty += 1e-8; + } + if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { + penalty -= 1e-8; + } + if (rcost + penalty < costs_r[i - first_used].Cost()) { + costs_r[i - first_used].cost = rcost; + costs_r[i - first_used].extra_cost = penalty; + costs_r[i - first_used].pred = + tree_samples.PredictorFromIndex(pred); + } + if (lcost + penalty < costs_l[i - first_used].Cost()) { + costs_l[i - first_used].cost = lcost; + costs_l[i - first_used].extra_cost = penalty; + costs_l[i - first_used].pred = + tree_samples.PredictorFromIndex(pred); + } + } + } + // Iterate through the possible splits and find the one with minimum sum + // of costs of the two sides. + size_t split = begin; + for (size_t i = first_used; i < last_used; i++) { + if (!prop_value_used_count[i]) continue; + split += prop_value_used_count[i]; + float rcost = costs_r[i - first_used].cost; + float lcost = costs_l[i - first_used].cost; + // WP was not used + we would use the WP property or predictor + bool adds_wp = + (tree_samples.PropertyFromIndex(prop) == kWPProp && + (used_properties & (1LU << prop)) == 0) || + ((costs_l[i - first_used].pred == Predictor::Weighted || + costs_r[i - first_used].pred == Predictor::Weighted) && + (*tree)[pos].predictor != Predictor::Weighted); + bool zero_entropy_side = rcost == 0 || lcost == 0; + + SplitInfo &best = + prop < kNumStaticProperties + ? (zero_entropy_side ? best_split_static_constant + : best_split_static) + : (adds_wp ? best_split_nonstatic : best_split_nowp); + if (lcost + rcost < best.Cost()) { + best.prop = prop; + best.val = i; + best.pos = split; + best.lcost = lcost; + best.lpred = costs_l[i - first_used].pred; + best.rcost = rcost; + best.rpred = costs_r[i - first_used].pred; + } + } + // Clear extra_bits_increase and cost_increase for last_used. + extra_bits_increase[last_used] = 0; + for (size_t sym = 0; sym < max_symbols; sym++) { + count_increase[last_used * max_symbols + sym] = 0; + } + } + + // Try to avoid introducing WP. + if (best_split_nowp.Cost() + threshold < base_bits && + best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { + best = &best_split_nowp; + } + // Split along static props if possible and not significantly more + // expensive. + if (best_split_static.Cost() + threshold < base_bits && + best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { + best = &best_split_static; + } + // Split along static props to create constant nodes if possible. + if (best_split_static_constant.Cost() + threshold < base_bits) { + best = &best_split_static_constant; + } + } + + if (best->Cost() + threshold < base_bits) { + uint32_t p = tree_samples.PropertyFromIndex(best->prop); + pixel_type dequant = + tree_samples.UnquantizeProperty(best->prop, best->val); + // Split node and try to split children. + MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); + // "Sort" according to winning property + SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop); + if (p >= kNumStaticProperties) { + used_properties |= 1 << best->prop; + } + auto new_sp_range = static_prop_range; + if (p < kNumStaticProperties) { + JXL_ASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); + new_sp_range[p][1] = dequant + 1; + JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]); + } + nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos, + used_properties, new_sp_range}); + new_sp_range = static_prop_range; + if (p < kNumStaticProperties) { + JXL_ASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); + new_sp_range[p][0] = dequant + 1; + JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]); + } + nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end, + used_properties, new_sp_range}); + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(FindBestSplit); // Local function. + +void ComputeBestTree(TreeSamples &tree_samples, float threshold, + const std::vector<ModularMultiplierInfo> &mul_info, + StaticPropRange static_prop_range, + float fast_decode_multiplier, Tree *tree) { + // TODO(veluca): take into account that different contexts can have different + // uint configs. + // + // Initialize tree. + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = tree_samples.PredictorFromIndex(0); + tree->back().predictor_offset = 0; + tree->back().multiplier = 1; + JXL_ASSERT(tree_samples.NumProperties() < 64); + + JXL_ASSERT(tree_samples.NumDistinctSamples() <= + std::numeric_limits<uint32_t>::max()); + HWY_DYNAMIC_DISPATCH(FindBestSplit) + (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier, + tree); +} + +constexpr int32_t TreeSamples::kPropertyRange; +constexpr uint32_t TreeSamples::kDedupEntryUnused; + +Status TreeSamples::SetPredictor(Predictor predictor, + ModularOptions::TreeMode wp_tree_mode) { + if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { + predictors = {Predictor::Weighted}; + residuals.resize(1); + return true; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP && + predictor == Predictor::Weighted) { + return JXL_FAILURE("Invalid predictor settings"); + } + if (predictor == Predictor::Variable) { + for (size_t i = 0; i < kNumModularPredictors; i++) { + predictors.push_back(static_cast<Predictor>(i)); + } + std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]); + std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]); + } else if (predictor == Predictor::Best) { + predictors = {Predictor::Weighted, Predictor::Gradient}; + } else { + predictors = {predictor}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { + auto wp_it = + std::find(predictors.begin(), predictors.end(), Predictor::Weighted); + if (wp_it != predictors.end()) { + predictors.erase(wp_it); + } + } + residuals.resize(predictors.size()); + return true; +} + +Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties, + ModularOptions::TreeMode wp_tree_mode) { + props_to_use = properties; + if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { + props_to_use = {static_cast<uint32_t>(kWPProp)}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) { + props_to_use = {static_cast<uint32_t>(kGradientProp)}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { + auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp); + if (it != props_to_use.end()) { + props_to_use.erase(it); + } + } + if (props_to_use.empty()) { + return JXL_FAILURE("Invalid property set configuration"); + } + props.resize(props_to_use.size()); + return true; +} + +void TreeSamples::InitTable(size_t size) { + JXL_DASSERT((size & (size - 1)) == 0); + if (dedup_table_.size() == size) return; + dedup_table_.resize(size, kDedupEntryUnused); + for (size_t i = 0; i < NumDistinctSamples(); i++) { + if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) { + AddToTable(i); + } + } +} + +bool TreeSamples::AddToTableAndMerge(size_t a) { + size_t pos1 = Hash1(a); + size_t pos2 = Hash2(a); + if (dedup_table_[pos1] != kDedupEntryUnused && + IsSameSample(a, dedup_table_[pos1])) { + JXL_DASSERT(sample_counts[a] == 1); + sample_counts[dedup_table_[pos1]]++; + // Remove from hash table samples that are saturated. + if (sample_counts[dedup_table_[pos1]] == + std::numeric_limits<uint16_t>::max()) { + dedup_table_[pos1] = kDedupEntryUnused; + } + return true; + } + if (dedup_table_[pos2] != kDedupEntryUnused && + IsSameSample(a, dedup_table_[pos2])) { + JXL_DASSERT(sample_counts[a] == 1); + sample_counts[dedup_table_[pos2]]++; + // Remove from hash table samples that are saturated. + if (sample_counts[dedup_table_[pos2]] == + std::numeric_limits<uint16_t>::max()) { + dedup_table_[pos2] = kDedupEntryUnused; + } + return true; + } + AddToTable(a); + return false; +} + +void TreeSamples::AddToTable(size_t a) { + size_t pos1 = Hash1(a); + size_t pos2 = Hash2(a); + if (dedup_table_[pos1] == kDedupEntryUnused) { + dedup_table_[pos1] = a; + } else if (dedup_table_[pos2] == kDedupEntryUnused) { + dedup_table_[pos2] = a; + } +} + +void TreeSamples::PrepareForSamples(size_t num_samples) { + for (auto &res : residuals) { + res.reserve(res.size() + num_samples); + } + for (auto &p : props) { + p.reserve(p.size() + num_samples); + } + size_t total_num_samples = num_samples + sample_counts.size(); + size_t next_pow2 = 1LLU << CeilLog2Nonzero(total_num_samples * 3 / 2); + InitTable(next_pow2); +} + +size_t TreeSamples::Hash1(size_t a) const { + constexpr uint64_t constant = 0x1e35a7bd; + uint64_t h = constant; + for (const auto &r : residuals) { + h = h * constant + r[a].tok; + h = h * constant + r[a].nbits; + } + for (const auto &p : props) { + h = h * constant + p[a]; + } + return (h >> 16) & (dedup_table_.size() - 1); +} +size_t TreeSamples::Hash2(size_t a) const { + constexpr uint64_t constant = 0x1e35a7bd1e35a7bd; + uint64_t h = constant; + for (const auto &p : props) { + h = h * constant ^ p[a]; + } + for (const auto &r : residuals) { + h = h * constant ^ r[a].tok; + h = h * constant ^ r[a].nbits; + } + return (h >> 16) & (dedup_table_.size() - 1); +} + +bool TreeSamples::IsSameSample(size_t a, size_t b) const { + bool ret = true; + for (const auto &r : residuals) { + if (r[a].tok != r[b].tok) { + ret = false; + } + if (r[a].nbits != r[b].nbits) { + ret = false; + } + } + for (const auto &p : props) { + if (p[a] != p[b]) { + ret = false; + } + } + return ret; +} + +void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties, + const pixel_type_w *predictions) { + for (size_t i = 0; i < predictors.size(); i++) { + pixel_type v = pixel - predictions[static_cast<int>(predictors[i])]; + uint32_t tok, nbits, bits; + HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits); + JXL_DASSERT(tok < 256); + JXL_DASSERT(nbits < 256); + residuals[i].emplace_back( + ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)}); + } + for (size_t i = 0; i < props_to_use.size(); i++) { + props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]])); + } + sample_counts.push_back(1); + num_samples++; + if (AddToTableAndMerge(sample_counts.size() - 1)) { + for (auto &r : residuals) r.pop_back(); + for (auto &p : props) p.pop_back(); + sample_counts.pop_back(); + } +} + +void TreeSamples::Swap(size_t a, size_t b) { + if (a == b) return; + for (auto &r : residuals) { + std::swap(r[a], r[b]); + } + for (auto &p : props) { + std::swap(p[a], p[b]); + } + std::swap(sample_counts[a], sample_counts[b]); +} + +void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) { + if (b == c) return Swap(a, b); + for (auto &r : residuals) { + auto tmp = r[a]; + r[a] = r[c]; + r[c] = r[b]; + r[b] = tmp; + } + for (auto &p : props) { + auto tmp = p[a]; + p[a] = p[c]; + p[c] = p[b]; + p[b] = tmp; + } + auto tmp = sample_counts[a]; + sample_counts[a] = sample_counts[c]; + sample_counts[c] = sample_counts[b]; + sample_counts[b] = tmp; +} + +namespace { +std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram, + size_t num_chunks) { + if (histogram.empty()) return {}; + // TODO(veluca): selecting distinct quantiles is likely not the best + // way to go about this. + std::vector<int32_t> thresholds; + size_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU); + size_t cumsum = 0; + size_t threshold = 0; + for (size_t i = 0; i + 1 < histogram.size(); i++) { + cumsum += histogram[i]; + if (cumsum > (threshold + 1) * sum / num_chunks) { + thresholds.push_back(i); + while (cumsum >= (threshold + 1) * sum / num_chunks) threshold++; + } + } + return thresholds; +} + +std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples, + size_t num_chunks) { + if (samples.empty()) return {}; + int min = *std::min_element(samples.begin(), samples.end()); + constexpr int kRange = 512; + min = std::min(std::max(min, -kRange), kRange); + std::vector<uint32_t> counts(2 * kRange + 1); + for (int s : samples) { + uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min; + counts[sample_offset]++; + } + std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks); + for (auto &v : thresholds) v += min; + return thresholds; +} +} // namespace + +void TreeSamples::PreQuantizeProperties( + const StaticPropRange &range, + const std::vector<ModularMultiplierInfo> &multiplier_info, + const std::vector<uint32_t> &group_pixel_count, + const std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples, size_t max_property_values) { + // If we have forced splits because of multipliers, choose channel and group + // thresholds accordingly. + std::vector<int32_t> group_multiplier_thresholds; + std::vector<int32_t> channel_multiplier_thresholds; + for (const auto &v : multiplier_info) { + if (v.range[0][0] != range[0][0]) { + channel_multiplier_thresholds.push_back(v.range[0][0] - 1); + } + if (v.range[0][1] != range[0][1]) { + channel_multiplier_thresholds.push_back(v.range[0][1] - 1); + } + if (v.range[1][0] != range[1][0]) { + group_multiplier_thresholds.push_back(v.range[1][0] - 1); + } + if (v.range[1][1] != range[1][1]) { + group_multiplier_thresholds.push_back(v.range[1][1] - 1); + } + } + std::sort(channel_multiplier_thresholds.begin(), + channel_multiplier_thresholds.end()); + channel_multiplier_thresholds.resize( + std::unique(channel_multiplier_thresholds.begin(), + channel_multiplier_thresholds.end()) - + channel_multiplier_thresholds.begin()); + std::sort(group_multiplier_thresholds.begin(), + group_multiplier_thresholds.end()); + group_multiplier_thresholds.resize( + std::unique(group_multiplier_thresholds.begin(), + group_multiplier_thresholds.end()) - + group_multiplier_thresholds.begin()); + + compact_properties.resize(props_to_use.size()); + auto quantize_channel = [&]() { + if (!channel_multiplier_thresholds.empty()) { + return channel_multiplier_thresholds; + } + return QuantizeHistogram(channel_pixel_count, max_property_values); + }; + auto quantize_group_id = [&]() { + if (!group_multiplier_thresholds.empty()) { + return group_multiplier_thresholds; + } + return QuantizeHistogram(group_pixel_count, max_property_values); + }; + auto quantize_coordinate = [&]() { + std::vector<int32_t> quantized; + quantized.reserve(max_property_values - 1); + for (size_t i = 0; i + 1 < max_property_values; i++) { + quantized.push_back((i + 1) * 256 / max_property_values - 1); + } + return quantized; + }; + std::vector<int32_t> abs_pixel_thr; + std::vector<int32_t> pixel_thr; + auto quantize_pixel_property = [&]() { + if (pixel_thr.empty()) { + pixel_thr = QuantizeSamples(pixel_samples, max_property_values); + } + return pixel_thr; + }; + auto quantize_abs_pixel_property = [&]() { + if (abs_pixel_thr.empty()) { + quantize_pixel_property(); // Compute the non-abs thresholds. + for (auto &v : pixel_samples) v = std::abs(v); + abs_pixel_thr = QuantizeSamples(pixel_samples, max_property_values); + } + return abs_pixel_thr; + }; + std::vector<int32_t> abs_diff_thr; + std::vector<int32_t> diff_thr; + auto quantize_diff_property = [&]() { + if (diff_thr.empty()) { + diff_thr = QuantizeSamples(diff_samples, max_property_values); + } + return diff_thr; + }; + auto quantize_abs_diff_property = [&]() { + if (abs_diff_thr.empty()) { + quantize_diff_property(); // Compute the non-abs thresholds. + for (auto &v : diff_samples) v = std::abs(v); + abs_diff_thr = QuantizeSamples(diff_samples, max_property_values); + } + return abs_diff_thr; + }; + auto quantize_wp = [&]() { + if (max_property_values < 32) { + return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0, + 1, 3, 7, 15, 31, 63, 127}; + } + if (max_property_values < 64) { + return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23, + -15, -11, -7, -5, -3, -1, 0, 1, + 3, 5, 7, 11, 15, 23, 31, 47, + 63, 95, 127, 191, 255}; + } + return std::vector<int32_t>{ + -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47, + -39, -31, -27, -23, -19, -15, -13, -11, -9, -7, -6, + -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, + 6, 7, 9, 11, 13, 15, 19, 23, 27, 31, 39, + 47, 55, 63, 79, 95, 111, 127, 159, 191, 223, 255}; + }; + + property_mapping.resize(props_to_use.size()); + for (size_t i = 0; i < props_to_use.size(); i++) { + if (props_to_use[i] == 0) { + compact_properties[i] = quantize_channel(); + } else if (props_to_use[i] == 1) { + compact_properties[i] = quantize_group_id(); + } else if (props_to_use[i] == 2 || props_to_use[i] == 3) { + compact_properties[i] = quantize_coordinate(); + } else if (props_to_use[i] == 6 || props_to_use[i] == 7 || + props_to_use[i] == 8 || + (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) { + compact_properties[i] = quantize_pixel_property(); + } else if (props_to_use[i] == 4 || props_to_use[i] == 5 || + (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) { + compact_properties[i] = quantize_abs_pixel_property(); + } else if (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 2) { + compact_properties[i] = quantize_abs_diff_property(); + } else if (props_to_use[i] == kWPProp) { + compact_properties[i] = quantize_wp(); + } else { + compact_properties[i] = quantize_diff_property(); + } + property_mapping[i].resize(kPropertyRange * 2 + 1); + size_t mapped = 0; + for (size_t j = 0; j < property_mapping[i].size(); j++) { + while (mapped < compact_properties[i].size() && + static_cast<int>(j) - kPropertyRange > + compact_properties[i][mapped]) { + mapped++; + } + // property_mapping[i] of a value V is `mapped` if + // compact_properties[i][mapped] <= j and + // compact_properties[i][mapped-1] > j + // This is because the decision node in the tree splits on (property) > j, + // hence everything that is not > of a threshold should be clustered + // together. + property_mapping[i][j] = mapped; + } + } +} + +void CollectPixelSamples(const Image &image, const ModularOptions &options, + size_t group_id, + std::vector<uint32_t> &group_pixel_count, + std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples) { + if (options.nb_repeats == 0) return; + if (group_pixel_count.size() <= group_id) { + group_pixel_count.resize(group_id + 1); + } + if (channel_pixel_count.size() < image.channel.size()) { + channel_pixel_count.resize(image.channel.size()); + } + Rng rng(group_id); + // Sample 10% of the final number of samples for property quantization. + float fraction = std::min(options.nb_repeats * 0.1, 0.99); + Rng::GeometricDistribution dist(fraction); + size_t total_pixels = 0; + std::vector<size_t> channel_ids; + for (size_t i = 0; i < image.channel.size(); i++) { + if (image.channel[i].w <= 1 || image.channel[i].h == 0) { + continue; // skip empty or width-1 channels. + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + channel_ids.push_back(i); + group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h; + channel_pixel_count[i] += image.channel[i].w * image.channel[i].h; + total_pixels += image.channel[i].w * image.channel[i].h; + } + if (channel_ids.empty()) return; + pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels); + diff_samples.reserve(diff_samples.size() + fraction * total_pixels); + size_t i = 0; + size_t y = 0; + size_t x = 0; + auto advance = [&](size_t amount) { + x += amount; + // Detect row overflow (rare). + while (x >= image.channel[channel_ids[i]].w) { + x -= image.channel[channel_ids[i]].w; + y++; + // Detect end-of-channel (even rarer). + if (y == image.channel[channel_ids[i]].h) { + i++; + y = 0; + if (i >= channel_ids.size()) { + return; + } + } + } + }; + advance(rng.Geometric(dist)); + for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) { + const pixel_type *row = image.channel[channel_ids[i]].Row(y); + pixel_samples.push_back(row[x]); + size_t xp = x == 0 ? 1 : x - 1; + diff_samples.push_back((int64_t)row[x] - row[xp]); + } +} + +// TODO(veluca): very simple encoding scheme. This should be improved. +void TokenizeTree(const Tree &tree, std::vector<Token> *tokens, + Tree *decoder_tree) { + JXL_ASSERT(tree.size() <= kMaxTreeSize); + std::queue<int> q; + q.push(0); + size_t leaf_id = 0; + decoder_tree->clear(); + while (!q.empty()) { + int cur = q.front(); + q.pop(); + JXL_ASSERT(tree[cur].property >= -1); + tokens->emplace_back(kPropertyContext, tree[cur].property + 1); + if (tree[cur].property == -1) { + tokens->emplace_back(kPredictorContext, + static_cast<int>(tree[cur].predictor)); + tokens->emplace_back(kOffsetContext, + PackSigned(tree[cur].predictor_offset)); + uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier); + uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1; + tokens->emplace_back(kMultiplierLogContext, mul_log); + tokens->emplace_back(kMultiplierBitsContext, mul_bits); + JXL_ASSERT(tree[cur].predictor < Predictor::Best); + decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor, + tree[cur].predictor_offset, + tree[cur].multiplier); + continue; + } + decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval, + decoder_tree->size() + q.size() + 1, + decoder_tree->size() + q.size() + 2, + Predictor::Zero, 0, 1); + q.push(tree[cur].lchild); + q.push(tree[cur].rchild); + tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval)); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h new file mode 100644 index 0000000000..ede37c8023 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ + +#include <numeric> + +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// Struct to collect all the data needed to build a tree. +struct TreeSamples { + bool HasSamples() const { + return !residuals.empty() && !residuals[0].empty(); + } + size_t NumDistinctSamples() const { return sample_counts.size(); } + size_t NumSamples() const { return num_samples; } + // Set the predictor to use. Must be called before adding any samples. + Status SetPredictor(Predictor predictor, + ModularOptions::TreeMode wp_tree_mode); + // Set the properties to use. Must be called before adding any samples. + Status SetProperties(const std::vector<uint32_t> &properties, + ModularOptions::TreeMode wp_tree_mode); + + size_t Token(size_t pred, size_t i) const { return residuals[pred][i].tok; } + size_t NBits(size_t pred, size_t i) const { return residuals[pred][i].nbits; } + size_t Count(size_t i) const { return sample_counts[i]; } + size_t PredictorIndex(Predictor predictor) const { + const auto predictor_elem = + std::find(predictors.begin(), predictors.end(), predictor); + JXL_DASSERT(predictor_elem != predictors.end()); + return predictor_elem - predictors.begin(); + } + size_t PropertyIndex(size_t property) const { + const auto property_elem = + std::find(props_to_use.begin(), props_to_use.end(), property); + JXL_DASSERT(property_elem != props_to_use.end()); + return property_elem - props_to_use.begin(); + } + size_t NumPropertyValues(size_t property_index) const { + return compact_properties[property_index].size() + 1; + } + // Returns the *quantized* property value. + size_t Property(size_t property_index, size_t i) const { + return props[property_index][i]; + } + int UnquantizeProperty(size_t property_index, uint32_t quant) const { + JXL_ASSERT(quant < compact_properties[property_index].size()); + return compact_properties[property_index][quant]; + } + + Predictor PredictorFromIndex(size_t index) const { + JXL_DASSERT(index < predictors.size()); + return predictors[index]; + } + size_t PropertyFromIndex(size_t index) const { + JXL_DASSERT(index < props_to_use.size()); + return props_to_use[index]; + } + size_t NumPredictors() const { return predictors.size(); } + size_t NumProperties() const { return props_to_use.size(); } + + // Preallocate data for a given number of samples. MUST be called before + // adding any sample. + void PrepareForSamples(size_t num_samples); + // Add a sample. + void AddSample(pixel_type_w pixel, const Properties &properties, + const pixel_type_w *predictions); + // Pre-cluster property values. + void PreQuantizeProperties( + const StaticPropRange &range, + const std::vector<ModularMultiplierInfo> &multiplier_info, + const std::vector<uint32_t> &group_pixel_count, + const std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples, size_t max_property_values); + + void AllSamplesDone() { dedup_table_ = std::vector<uint32_t>(); } + + uint32_t QuantizeProperty(uint32_t prop, pixel_type v) const { + v = std::min(std::max(v, -kPropertyRange), kPropertyRange) + kPropertyRange; + return property_mapping[prop][v]; + } + + // Swaps samples in position a and b. Does nothing if a == b. + void Swap(size_t a, size_t b); + + // Cycles samples: a -> b -> c -> a. We assume a <= b <= c, so that we can + // just call Swap(a, b) if b==c. + void ThreeShuffle(size_t a, size_t b, size_t c); + + private: + // TODO(veluca): as the total number of properties and predictors are known + // before adding any samples, it might be better to interleave predictors, + // properties and counts in a single vector to improve locality. + // A first attempt at doing this actually results in much slower encoding, + // possibly because of the more complex addressing. + struct ResidualToken { + uint8_t tok; + uint8_t nbits; + }; + // Residual information: token and number of extra bits, per predictor. + std::vector<std::vector<ResidualToken>> residuals; + // Number of occurrences of each sample. + std::vector<uint16_t> sample_counts; + // Property values, quantized to at most 256 distinct values. + std::vector<std::vector<uint8_t>> props; + // Decompactification info for `props`. + std::vector<std::vector<int32_t>> compact_properties; + // List of properties to use. + std::vector<uint32_t> props_to_use; + // List of predictors to use. + std::vector<Predictor> predictors; + // Mapping property value -> quantized property value. + static constexpr int32_t kPropertyRange = 511; + std::vector<std::vector<uint8_t>> property_mapping; + // Number of samples seen. + size_t num_samples = 0; + // Table for deduplication. + static constexpr uint32_t kDedupEntryUnused{static_cast<uint32_t>(-1)}; + std::vector<uint32_t> dedup_table_; + + // Functions for sample deduplication. + bool IsSameSample(size_t a, size_t b) const; + size_t Hash1(size_t a) const; + size_t Hash2(size_t a) const; + void InitTable(size_t size); + // Returns true if `a` was already present in the table. + bool AddToTableAndMerge(size_t a); + void AddToTable(size_t a); +}; + +void TokenizeTree(const Tree &tree, std::vector<Token> *tokens, + Tree *decoder_tree); + +void CollectPixelSamples(const Image &image, const ModularOptions &options, + size_t group_id, + std::vector<uint32_t> &group_pixel_count, + std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples); + +void ComputeBestTree(TreeSamples &tree_samples, float threshold, + const std::vector<ModularMultiplierInfo> &mul_info, + StaticPropRange static_prop_range, + float fast_decode_multiplier, Tree *tree); + +} // namespace jxl +#endif // LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc new file mode 100644 index 0000000000..9d2c3e5cf9 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc @@ -0,0 +1,622 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/encoding.h" + +#include <stdint.h> +#include <stdlib.h> + +#include <queue> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/scope_guard.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// Removes all nodes that use a static property (i.e. channel or group ID) from +// the tree and collapses each node on even levels with its two children to +// produce a flatter tree. Also computes whether the resulting tree requires +// using the weighted predictor. +FlatTree FilterTree(const Tree &global_tree, + std::array<pixel_type, kNumStaticProperties> &static_props, + size_t *num_props, bool *use_wp, bool *wp_only, + bool *gradient_only) { + *num_props = 0; + bool has_wp = false; + bool has_non_wp = false; + *gradient_only = true; + const auto mark_property = [&](int32_t p) { + if (p == kWPProp) { + has_wp = true; + } else if (p >= kNumStaticProperties) { + has_non_wp = true; + } + if (p >= kNumStaticProperties && p != kGradientProp) { + *gradient_only = false; + } + }; + FlatTree output; + std::queue<size_t> nodes; + nodes.push(0); + // Produces a trimmed and flattened tree by doing a BFS visit of the original + // tree, ignoring branches that are known to be false and proceeding two + // levels at a time to collapse nodes in a flatter tree; if an inner parent + // node has a leaf as a child, the leaf is duplicated and an implicit fake + // node is added. This allows to reduce the number of branches when traversing + // the resulting flat tree. + while (!nodes.empty()) { + size_t cur = nodes.front(); + nodes.pop(); + // Skip nodes that we can decide now, by jumping directly to their children. + while (global_tree[cur].property < kNumStaticProperties && + global_tree[cur].property != -1) { + if (static_props[global_tree[cur].property] > global_tree[cur].splitval) { + cur = global_tree[cur].lchild; + } else { + cur = global_tree[cur].rchild; + } + } + FlatDecisionNode flat; + if (global_tree[cur].property == -1) { + flat.property0 = -1; + flat.childID = global_tree[cur].lchild; + flat.predictor = global_tree[cur].predictor; + flat.predictor_offset = global_tree[cur].predictor_offset; + flat.multiplier = global_tree[cur].multiplier; + *gradient_only &= flat.predictor == Predictor::Gradient; + has_wp |= flat.predictor == Predictor::Weighted; + has_non_wp |= flat.predictor != Predictor::Weighted; + output.push_back(flat); + continue; + } + flat.childID = output.size() + nodes.size() + 1; + + flat.property0 = global_tree[cur].property; + *num_props = std::max<size_t>(flat.property0 + 1, *num_props); + flat.splitval0 = global_tree[cur].splitval; + + for (size_t i = 0; i < 2; i++) { + size_t cur_child = + i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild; + // Skip nodes that we can decide now. + while (global_tree[cur_child].property < kNumStaticProperties && + global_tree[cur_child].property != -1) { + if (static_props[global_tree[cur_child].property] > + global_tree[cur_child].splitval) { + cur_child = global_tree[cur_child].lchild; + } else { + cur_child = global_tree[cur_child].rchild; + } + } + // We ended up in a leaf, add a dummy decision and two copies of the leaf. + if (global_tree[cur_child].property == -1) { + flat.properties[i] = 0; + flat.splitvals[i] = 0; + nodes.push(cur_child); + nodes.push(cur_child); + } else { + flat.properties[i] = global_tree[cur_child].property; + flat.splitvals[i] = global_tree[cur_child].splitval; + nodes.push(global_tree[cur_child].lchild); + nodes.push(global_tree[cur_child].rchild); + *num_props = std::max<size_t>(flat.properties[i] + 1, *num_props); + } + } + + for (size_t j = 0; j < 2; j++) mark_property(flat.properties[j]); + mark_property(flat.property0); + output.push_back(flat); + } + if (*num_props > kNumNonrefProperties) { + *num_props = + DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) * + kExtraPropsPerChannel + + kNumNonrefProperties; + } else { + *num_props = kNumNonrefProperties; + } + *use_wp = has_wp; + *wp_only = has_wp && !has_non_wp; + + return output; +} + +Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader, + const std::vector<uint8_t> &context_map, + const Tree &global_tree, + const weighted::Header &wp_header, + pixel_type chan, size_t group_id, + Image *image) { + Channel &channel = image->channel[chan]; + + std::array<pixel_type, kNumStaticProperties> static_props = { + {chan, (int)group_id}}; + // TODO(veluca): filter the tree according to static_props. + + // zero pixel channel? could happen + if (channel.w == 0 || channel.h == 0) return true; + + bool tree_has_wp_prop_or_pred = false; + bool is_wp_only = false; + bool is_gradient_only = false; + size_t num_props; + FlatTree tree = + FilterTree(global_tree, static_props, &num_props, + &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only); + + // From here on, tree lookup returns a *clustered* context ID. + // This avoids an extra memory lookup after tree traversal. + for (size_t i = 0; i < tree.size(); i++) { + if (tree[i].property0 == -1) { + tree[i].childID = context_map[tree[i].childID]; + } + } + + JXL_DEBUG_V(3, "Decoded MA tree with %" PRIuS " nodes", tree.size()); + + // MAANS decode + const auto make_pixel = [](uint64_t v, pixel_type multiplier, + pixel_type_w offset) -> pixel_type { + JXL_DASSERT((v & 0xFFFFFFFF) == v); + pixel_type_w val = UnpackSigned(v); + // if it overflows, it overflows, and we have a problem anyway + return val * multiplier + offset; + }; + + if (tree.size() == 1) { + // special optimized case: no meta-adaptation, so no need + // to compute properties. + Predictor predictor = tree[0].predictor; + int64_t offset = tree[0].predictor_offset; + int32_t multiplier = tree[0].multiplier; + size_t ctx_id = tree[0].childID; + if (predictor == Predictor::Zero) { + uint32_t value; + if (reader->IsSingleValueAndAdvance(ctx_id, &value, + channel.w * channel.h)) { + // Special-case: histogram has a single symbol, with no extra bits, and + // we use ANS mode. + JXL_DEBUG_V(8, "Fastest track."); + pixel_type v = make_pixel(value, multiplier, offset); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + std::fill(r, r + channel.w, v); + } + } else { + JXL_DEBUG_V(8, "Fast track."); + if (multiplier == 1 && offset == 0) { + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + uint32_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = UnpackSigned(v); + } + } + } else { + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + uint32_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multiplier, offset); + } + } + } + } + } else if (predictor == Predictor::Gradient && offset == 0 && + multiplier == 1 && reader->HuffRleOnly()) { + JXL_DEBUG_V(8, "Gradient RLE (fjxl) very fast track."); + uint32_t run = 0; + uint32_t v = 0; + pixel_type_w sv = 0; + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1); + const pixel_type *JXL_RESTRICT rtopleft = + (y ? channel.Row(y - 1) - 1 : r - 1); + pixel_type_w guess = (y ? rtop[0] : 0); + if (run == 0) { + reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run); + sv = UnpackSigned(v); + } else { + run--; + } + r[0] = sv + guess; + for (size_t x = 1; x < channel.w; x++) { + pixel_type left = r[x - 1]; + pixel_type top = rtop[x]; + pixel_type topleft = rtopleft[x]; + pixel_type_w guess = ClampedGradient(top, left, topleft); + if (!run) { + reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run); + sv = UnpackSigned(v); + } else { + run--; + } + r[x] = sv + guess; + } + } + } else if (predictor == Predictor::Gradient && offset == 0 && + multiplier == 1) { + JXL_DEBUG_V(8, "Gradient very fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type top = (y ? *(r + x - onerow) : left); + pixel_type topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type guess = ClampedGradient(top, left, topleft); + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, 1, guess); + } + } + } else if (predictor != Predictor::Weighted) { + // special optimized case: no wp + JXL_DEBUG_V(8, "Quite fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult pred = + PredictNoTreeNoWP(channel.w, r + x, onerow, x, y, predictor); + pixel_type_w g = pred.guess + offset; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + // NOTE: pred.multiplier is unset. + r[x] = make_pixel(v, multiplier, g); + } + } + } else { + JXL_DEBUG_V(8, "Somewhat fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w g = PredictNoTreeWP(channel.w, r + x, onerow, x, y, + predictor, &wp_state) + .guess + + offset; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multiplier, g); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } + return true; + } + + // Check if this tree is a WP-only tree with a small enough property value + // range. + // Initialized to avoid clang-tidy complaining. + uint8_t context_lookup[2 * kPropRangeFast] = {}; + int8_t multipliers[2 * kPropRangeFast] = {}; + int8_t offsets[2 * kPropRangeFast] = {}; + if (is_wp_only) { + is_wp_only = TreeToLookupTable(tree, context_lookup, offsets, multipliers); + } + if (is_gradient_only) { + is_gradient_only = + TreeToLookupTable(tree, context_lookup, offsets, multipliers); + } + + if (is_gradient_only) { + JXL_DEBUG_V(8, "Gradient fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + uint32_t pos = + kPropRangeFast + + std::min<pixel_type_w>( + std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multipliers[pos], + static_cast<pixel_type_w>(offsets[pos]) + guess); + } + } + } else if (is_wp_only) { + JXL_DEBUG_V(8, "WP fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + Properties properties(1); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + int32_t guess = wp_state.Predict</*compute_properties=*/true>( + x, y, channel.w, top, left, topright, topleft, toptop, &properties, + offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + uint64_t v = reader->ReadHybridUintClustered(ctx_id, br); + r[x] = make_pixel(v, multipliers[pos], + static_cast<pixel_type_w>(offsets[pos]) + guess); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } else if (!tree_has_wp_prop_or_pred) { + // special optimized case: the weighted predictor and its properties are not + // used, so no need to compute weights and properties. + JXL_DEBUG_V(8, "Slow track."); + MATreeLookup tree_lookup(tree); + Properties properties = Properties(num_props); + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, *image, chan, &references); + InitPropsRow(&properties, static_props, y); + if (y > 1 && channel.w > 8 && references.w == 0) { + for (size_t x = 0; x < 2; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + for (size_t x = 2; x < channel.w - 2; x++) { + PredictionResult res = + PredictTreeNoWPNEC(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + for (size_t x = channel.w - 2; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + } else { + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + } + } + } else { + JXL_DEBUG_V(8, "Slowest track."); + MATreeLookup tree_lookup(tree); + Properties properties = Properties(num_props); + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + InitPropsRow(&properties, static_props, y); + PrecomputeReferences(channel, y, *image, chan, &references); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + uint64_t v = reader->ReadHybridUintClustered(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + } + return true; +} + +GroupHeader::GroupHeader() { Bundle::Init(this); } + +Status ValidateChannelDimensions(const Image &image, + const ModularOptions &options) { + size_t nb_channels = image.channel.size(); + for (bool is_dc : {true, false}) { + size_t group_dim = options.group_dim * (is_dc ? kBlockDim : 1); + size_t c = image.nb_meta_channels; + for (; c < nb_channels; c++) { + const Channel &ch = image.channel[c]; + if (ch.w > options.group_dim || ch.h > options.group_dim) break; + } + for (; c < nb_channels; c++) { + const Channel &ch = image.channel[c]; + if (ch.w == 0 || ch.h == 0) continue; // skip empty + bool is_dc_channel = std::min(ch.hshift, ch.vshift) >= 3; + if (is_dc_channel != is_dc) continue; + size_t tile_dim = group_dim >> std::max(ch.hshift, ch.vshift); + if (tile_dim == 0) { + return JXL_FAILURE("Inconsistent transforms"); + } + } + } + return true; +} + +Status ModularDecode(BitReader *br, Image &image, GroupHeader &header, + size_t group_id, ModularOptions *options, + const Tree *global_tree, const ANSCode *global_code, + const std::vector<uint8_t> *global_ctx_map, + bool allow_truncated_group) { + if (image.channel.empty()) return true; + + // decode transforms + Status status = Bundle::Read(br, &header); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(status); + if (status.IsFatalError()) return status; + if (!br->AllReadsWithinBounds()) { + // Don't do/undo transforms if header is incomplete. + header.transforms.clear(); + image.transform = header.transforms; + for (size_t c = 0; c < image.channel.size(); c++) { + ZeroFillImage(&image.channel[c].plane); + } + return Status(StatusCode::kNotEnoughBytes); + } + + JXL_DEBUG_V(3, "Image data underwent %" PRIuS " transformations: ", + header.transforms.size()); + image.transform = header.transforms; + for (Transform &transform : image.transform) { + JXL_RETURN_IF_ERROR(transform.MetaApply(image)); + } + if (image.error) { + return JXL_FAILURE("Corrupt file. Aborting."); + } + JXL_RETURN_IF_ERROR(ValidateChannelDimensions(image, *options)); + + size_t nb_channels = image.channel.size(); + + size_t num_chans = 0; + size_t distance_multiplier = 0; + for (size_t i = 0; i < nb_channels; i++) { + Channel &channel = image.channel[i]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + if (channel.w > distance_multiplier) { + distance_multiplier = channel.w; + } + num_chans++; + } + if (num_chans == 0) return true; + + size_t next_channel = 0; + auto scope_guard = MakeScopeGuard([&]() { + // Do not do anything if truncated groups are not allowed. + if (!allow_truncated_group) return; + for (size_t c = next_channel; c < nb_channels; c++) { + ZeroFillImage(&image.channel[c].plane); + } + }); + + // Read tree. + Tree tree_storage; + std::vector<uint8_t> context_map_storage; + ANSCode code_storage; + const Tree *tree = &tree_storage; + const ANSCode *code = &code_storage; + const std::vector<uint8_t> *context_map = &context_map_storage; + if (!header.use_global_tree) { + size_t max_tree_size = 1024; + for (size_t i = 0; i < nb_channels; i++) { + Channel &channel = image.channel[i]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + size_t pixels = channel.w * channel.h; + if (pixels / channel.w != channel.h) { + return JXL_FAILURE("Tree size overflow"); + } + max_tree_size += pixels; + if (max_tree_size < pixels) return JXL_FAILURE("Tree size overflow"); + } + max_tree_size = std::min(static_cast<size_t>(1 << 20), max_tree_size); + JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, max_tree_size)); + JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2, + &code_storage, &context_map_storage)); + } else { + if (!global_tree || !global_code || !global_ctx_map || + global_tree->empty()) { + return JXL_FAILURE("No global tree available but one was requested"); + } + tree = global_tree; + code = global_code; + context_map = global_ctx_map; + } + + // Read channels + ANSSymbolReader reader(code, br, distance_multiplier); + for (; next_channel < nb_channels; next_channel++) { + Channel &channel = image.channel[next_channel]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (next_channel >= image.nb_meta_channels && + (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS( + br, &reader, *context_map, *tree, header.wp_header, next_channel, + group_id, &image)); + // Truncated group. + if (!br->AllReadsWithinBounds()) { + if (!allow_truncated_group) return JXL_FAILURE("Truncated input"); + return Status(StatusCode::kNotEnoughBytes); + } + } + + // Make sure no zero-filling happens even if next_channel < nb_channels. + scope_guard.Disarm(); + + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("ANS decode final state failed"); + } + return true; +} + +Status ModularGenericDecompress(BitReader *br, Image &image, + GroupHeader *header, size_t group_id, + ModularOptions *options, bool undo_transforms, + const Tree *tree, const ANSCode *code, + const std::vector<uint8_t> *ctx_map, + bool allow_truncated_group) { +#ifdef JXL_ENABLE_ASSERT + std::vector<std::pair<uint32_t, uint32_t>> req_sizes(image.channel.size()); + for (size_t c = 0; c < req_sizes.size(); c++) { + req_sizes[c] = {image.channel[c].w, image.channel[c].h}; + } +#endif + GroupHeader local_header; + if (header == nullptr) header = &local_header; + size_t bit_pos = br->TotalBitsConsumed(); + auto dec_status = ModularDecode(br, image, *header, group_id, options, tree, + code, ctx_map, allow_truncated_group); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status); + if (dec_status.IsFatalError()) return dec_status; + if (undo_transforms) image.undo_transforms(header->wp_header); + if (image.error) return JXL_FAILURE("Corrupt file. Aborting."); + JXL_DEBUG_V(4, + "Modular-decoded a %" PRIuS "x%" PRIuS " nbchans=%" PRIuS + " image from %" PRIuS " bytes", + image.w, image.h, image.channel.size(), + (br->TotalBitsConsumed() - bit_pos) / 8); + JXL_DEBUG_V(5, "Modular image: %s", image.DebugString().c_str()); + (void)bit_pos; +#ifdef JXL_ENABLE_ASSERT + // Check that after applying all transforms we are back to the requested image + // sizes, otherwise there's a programming error with the transformations. + if (undo_transforms) { + JXL_ASSERT(image.channel.size() == req_sizes.size()); + for (size_t c = 0; c < req_sizes.size(); c++) { + JXL_ASSERT(req_sizes[c].first == image.channel[c].w); + JXL_ASSERT(req_sizes[c].second == image.channel[c].h); + } + } +#endif + return dec_status; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h new file mode 100644 index 0000000000..89697bce87 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h @@ -0,0 +1,135 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENCODING_H_ +#define LIB_JXL_MODULAR_ENCODING_ENCODING_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +// Valid range of properties for using lookup tables instead of trees. +constexpr int32_t kPropRangeFast = 512; + +struct GroupHeader : public Fields { + GroupHeader(); + + JXL_FIELDS_NAME(GroupHeader) + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &use_global_tree)); + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&wp_header)); + uint32_t num_transforms = static_cast<uint32_t>(transforms.size()); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(4, 2), + BitsOffset(8, 18), 0, + &num_transforms)); + if (visitor->IsReading()) transforms.resize(num_transforms); + for (size_t i = 0; i < num_transforms; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&transforms[i])); + } + return true; + } + + bool use_global_tree; + weighted::Header wp_header; + + std::vector<Transform> transforms; +}; + +FlatTree FilterTree(const Tree &global_tree, + std::array<pixel_type, kNumStaticProperties> &static_props, + size_t *num_props, bool *use_wp, bool *wp_only, + bool *gradient_only); + +template <typename T> +bool TreeToLookupTable(const FlatTree &tree, + T context_lookup[2 * kPropRangeFast], + int8_t offsets[2 * kPropRangeFast], + int8_t multipliers[2 * kPropRangeFast] = nullptr) { + struct TreeRange { + // Begin *excluded*, end *included*. This works best with > vs <= decision + // nodes. + int begin, end; + size_t pos; + }; + std::vector<TreeRange> ranges; + ranges.push_back(TreeRange{-kPropRangeFast - 1, kPropRangeFast - 1, 0}); + while (!ranges.empty()) { + TreeRange cur = ranges.back(); + ranges.pop_back(); + if (cur.begin < -kPropRangeFast - 1 || cur.begin >= kPropRangeFast - 1 || + cur.end > kPropRangeFast - 1) { + // Tree is outside the allowed range, exit. + return false; + } + auto &node = tree[cur.pos]; + // Leaf. + if (node.property0 == -1) { + if (node.predictor_offset < std::numeric_limits<int8_t>::min() || + node.predictor_offset > std::numeric_limits<int8_t>::max()) { + return false; + } + if (node.multiplier < std::numeric_limits<int8_t>::min() || + node.multiplier > std::numeric_limits<int8_t>::max()) { + return false; + } + if (multipliers == nullptr && node.multiplier != 1) { + return false; + } + for (int i = cur.begin + 1; i < cur.end + 1; i++) { + context_lookup[i + kPropRangeFast] = node.childID; + if (multipliers) multipliers[i + kPropRangeFast] = node.multiplier; + offsets[i + kPropRangeFast] = node.predictor_offset; + } + continue; + } + // > side of top node. + if (node.properties[0] >= kNumStaticProperties) { + ranges.push_back(TreeRange({node.splitvals[0], cur.end, node.childID})); + ranges.push_back( + TreeRange({node.splitval0, node.splitvals[0], node.childID + 1})); + } else { + ranges.push_back(TreeRange({node.splitval0, cur.end, node.childID})); + } + // <= side + if (node.properties[1] >= kNumStaticProperties) { + ranges.push_back( + TreeRange({node.splitvals[1], node.splitval0, node.childID + 2})); + ranges.push_back( + TreeRange({cur.begin, node.splitvals[1], node.childID + 3})); + } else { + ranges.push_back( + TreeRange({cur.begin, node.splitval0, node.childID + 2})); + } + } + return true; +} +// TODO(veluca): make cleaner interfaces. + +Status ValidateChannelDimensions(const Image &image, + const ModularOptions &options); + +Status ModularGenericDecompress(BitReader *br, Image &image, + GroupHeader *header, size_t group_id, + ModularOptions *options, + bool undo_transforms = true, + const Tree *tree = nullptr, + const ANSCode *code = nullptr, + const std::vector<uint8_t> *ctx_map = nullptr, + bool allow_truncated_group = false); +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENCODING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h new file mode 100644 index 0000000000..71b7847321 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ +#define LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ + +#include <stddef.h> + +namespace jxl { + +enum MATreeContext : size_t { + kSplitValContext = 0, + kPropertyContext = 1, + kPredictorContext = 2, + kOffsetContext = 3, + kMultiplierLogContext = 4, + kMultiplierBitsContext = 5, + + kNumTreeContexts = 6, +}; + +static constexpr size_t kMaxTreeSize = 1 << 22; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc b/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc new file mode 100644 index 0000000000..785d0c5443 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc @@ -0,0 +1,77 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/modular_image.h" + +#include <sstream> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +void Image::undo_transforms(const weighted::Header &wp_header, + jxl::ThreadPool *pool) { + while (!transform.empty()) { + Transform t = transform.back(); + JXL_DEBUG_V(4, "Undoing transform"); + Status result = t.Inverse(*this, wp_header, pool); + if (result == false) { + JXL_NOTIFY_ERROR("Error while undoing transform."); + error = true; + return; + } + JXL_DEBUG_V(8, "Undoing transform: done"); + transform.pop_back(); + } +} + +Image::Image(size_t iw, size_t ih, int bitdepth, int nb_chans) + : w(iw), h(ih), bitdepth(bitdepth), nb_meta_channels(0), error(false) { + for (int i = 0; i < nb_chans; i++) channel.emplace_back(Channel(iw, ih)); +} + +Image::Image() : w(0), h(0), bitdepth(8), nb_meta_channels(0), error(true) {} + +Image &Image::operator=(Image &&other) noexcept { + w = other.w; + h = other.h; + bitdepth = other.bitdepth; + nb_meta_channels = other.nb_meta_channels; + error = other.error; + channel = std::move(other.channel); + transform = std::move(other.transform); + return *this; +} + +Image Image::clone() { + Image c(w, h, bitdepth, 0); + c.nb_meta_channels = nb_meta_channels; + c.error = error; + c.transform = transform; + for (Channel &ch : channel) { + Channel a(ch.w, ch.h, ch.hshift, ch.vshift); + CopyImageTo(ch.plane, &a.plane); + c.channel.push_back(std::move(a)); + } + return c; +} + +std::string Image::DebugString() const { + std::ostringstream os; + os << w << "x" << h << ", depth: " << bitdepth; + if (!channel.empty()) { + os << ", channels:"; + for (size_t i = 0; i < channel.size(); ++i) { + os << " " << channel[i].w << "x" << channel[i].h + << "(shift: " << channel[i].hshift << "," << channel[i].vshift << ")"; + if (i < nb_meta_channels) os << "*"; + } + } + return os.str(); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/modular_image.h b/third_party/jpeg-xl/lib/jxl/modular/modular_image.h new file mode 100644 index 0000000000..3e9b5a8a08 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/modular_image.h @@ -0,0 +1,118 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_MODULAR_IMAGE_H_ +#define LIB_JXL_MODULAR_MODULAR_IMAGE_H_ + +#include <stddef.h> +#include <stdint.h> +#include <stdio.h> +#include <string.h> + +#include <string> +#include <utility> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +typedef int32_t pixel_type; // can use int16_t if it's only for 8-bit images. + // Need some wiggle room for YCoCg / Squeeze etc + +typedef int64_t pixel_type_w; + +namespace weighted { +struct Header; +} + +class Channel { + public: + jxl::Plane<pixel_type> plane; + size_t w, h; + int hshift, vshift; // w ~= image.w >> hshift; h ~= image.h >> vshift + Channel(size_t iw, size_t ih, int hsh = 0, int vsh = 0) + : plane(iw, ih), w(iw), h(ih), hshift(hsh), vshift(vsh) {} + + Channel(const Channel& other) = delete; + Channel& operator=(const Channel& other) = delete; + + // Move assignment + Channel& operator=(Channel&& other) noexcept { + w = other.w; + h = other.h; + hshift = other.hshift; + vshift = other.vshift; + plane = std::move(other.plane); + return *this; + } + + // Move constructor + Channel(Channel&& other) noexcept = default; + + void shrink() { + if (plane.xsize() == w && plane.ysize() == h) return; + jxl::Plane<pixel_type> resizedplane(w, h); + plane = std::move(resizedplane); + } + void shrink(int nw, int nh) { + w = nw; + h = nh; + shrink(); + } + + JXL_INLINE pixel_type* Row(const size_t y) { return plane.Row(y); } + JXL_INLINE const pixel_type* Row(const size_t y) const { + return plane.Row(y); + } +}; + +class Transform; + +class Image { + public: + // image data, transforms can dramatically change the number of channels and + // their semantics + std::vector<Channel> channel; + // transforms that have been applied (and that have to be undone) + std::vector<Transform> transform; + + // image dimensions (channels may have different dimensions due to transforms) + size_t w, h; + int bitdepth; + size_t nb_meta_channels; // first few channels might contain palette(s) + bool error; // true if a fatal error occurred, false otherwise + + Image(size_t iw, size_t ih, int bitdepth, int nb_chans); + Image(); + + Image(const Image& other) = delete; + Image& operator=(const Image& other) = delete; + + Image& operator=(Image&& other) noexcept; + Image(Image&& other) noexcept = default; + + bool empty() const { + for (const auto& ch : channel) { + if (ch.w && ch.h) return false; + } + return true; + } + + Image clone(); + + void undo_transforms(const weighted::Header& wp_header, + jxl::ThreadPool* pool = nullptr); + + std::string DebugString() const; +}; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_MODULAR_IMAGE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/options.h b/third_party/jpeg-xl/lib/jxl/modular/options.h new file mode 100644 index 0000000000..ce6596b912 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/options.h @@ -0,0 +1,117 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_OPTIONS_H_ +#define LIB_JXL_MODULAR_OPTIONS_H_ + +#include <stdint.h> + +#include <array> +#include <vector> + +namespace jxl { + +using PropertyVal = int32_t; +using Properties = std::vector<PropertyVal>; + +enum class Predictor : uint32_t { + Zero = 0, + Left = 1, + Top = 2, + Average0 = 3, + Select = 4, + Gradient = 5, + Weighted = 6, + TopRight = 7, + TopLeft = 8, + LeftLeft = 9, + Average1 = 10, + Average2 = 11, + Average3 = 12, + Average4 = 13, + // The following predictors are encoder-only. + Best = 14, // Best of Gradient and Weighted + Variable = + 15, // Find the best decision tree for predictors/predictor per row +}; + +constexpr size_t kNumModularPredictors = + static_cast<size_t>(Predictor::Average4) + 1; +constexpr size_t kNumModularEncoderPredictors = + static_cast<size_t>(Predictor::Variable) + 1; + +static constexpr ssize_t kNumStaticProperties = 2; // channel, group_id. + +using StaticPropRange = + std::array<std::array<uint32_t, 2>, kNumStaticProperties>; + +struct ModularMultiplierInfo { + StaticPropRange range; + uint32_t multiplier; +}; + +struct ModularOptions { + /// Used in both encode and decode: + + // Stop encoding/decoding when reaching a (non-meta) channel that has a + // dimension bigger than max_chan_size. + size_t max_chan_size = 0xFFFFFF; + + // Used during decoding for validation of transforms (sqeeezing) scheme. + size_t group_dim = 0x1FFFFFFF; + + /// Encode options: + // Fraction of pixels to look at to learn a MA tree + // Number of iterations to do to learn a MA tree + // (if zero there is no MA context model) + float nb_repeats = .5f; + + // Maximum number of (previous channel) properties to use in the MA trees + int max_properties = 0; // no previous channels + + // Alternative heuristic tweaks. + // Properties default to channel, group, weighted, gradient residual, W-NW, + // NW-N, N-NE, N-NN + std::vector<uint32_t> splitting_heuristics_properties = {0, 1, 15, 9, + 10, 11, 12, 13}; + float splitting_heuristics_node_threshold = 96; + size_t max_property_values = 32; + + // Predictor to use for each channel. + Predictor predictor = static_cast<Predictor>(-1); + + int wp_mode = 0; + + float fast_decode_multiplier = 1.01f; + + // Forces the encoder to produce a tree that is compatible with the WP-only + // decode path (or with the no-wp path, or the gradient-only path). + enum class TreeMode { kGradientOnly, kWPOnly, kNoWP, kDefault }; + TreeMode wp_tree_mode = TreeMode::kDefault; + + // Skip fast paths in the encoder. + bool skip_encoder_fast_path = false; + + // Kind of tree to use. + // TODO(veluca): add tree kinds for JPEG recompression with CfL enabled, + // general AC metadata, different DC qualities, and others. + enum class TreeKind { + kTrivialTreeNoPredictor, + kLearn, + kJpegTranscodeACMeta, + kFalconACMeta, + kACMeta, + kWPFixedDC, + kGradientFixedDC, + }; + TreeKind tree_kind = TreeKind::kLearn; + + // Ignore the image and just pretend all tokens are zeroes + bool zero_tokens = false; +}; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_OPTIONS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc new file mode 100644 index 0000000000..bc31445bc5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc @@ -0,0 +1,606 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_palette.h" + +#include <array> +#include <map> +#include <set> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/enc_transform.h" +#include "lib/jxl/modular/transform/palette.h" + +namespace jxl { + +namespace palette_internal { + +static constexpr bool kEncodeToHighQualityImplicitPalette = true; + +// Inclusive. +static constexpr int kMinImplicitPaletteIndex = -(2 * 72 - 1); + +float ColorDistance(const std::vector<float> &JXL_RESTRICT a, + const std::vector<pixel_type> &JXL_RESTRICT b) { + JXL_ASSERT(a.size() == b.size()); + float distance = 0; + float ave3 = 0; + if (a.size() >= 3) { + ave3 = (a[0] + b[0] + a[1] + b[1] + a[2] + b[2]) * (1.21f / 3.0f); + } + float sum_a = 0, sum_b = 0; + for (size_t c = 0; c < a.size(); ++c) { + const float difference = + static_cast<float>(a[c]) - static_cast<float>(b[c]); + float weight = c == 0 ? 3 : c == 1 ? 5 : 2; + if (c < 3 && (a[c] + b[c] >= ave3)) { + const float add_w[3] = { + 1.15, + 1.15, + 1.12, + }; + weight += add_w[c]; + if (c == 2 && ((a[2] + b[2]) < 1.22 * ave3)) { + weight -= 0.5; + } + } + distance += difference * difference * weight * weight; + const int sum_weight = c == 0 ? 3 : c == 1 ? 5 : 1; + sum_a += a[c] * sum_weight; + sum_b += b[c] * sum_weight; + } + distance *= 4; + float sum_difference = sum_a - sum_b; + distance += sum_difference * sum_difference; + return distance; +} + +static int QuantizeColorToImplicitPaletteIndex( + const std::vector<pixel_type> &color, const int palette_size, + const int bit_depth, bool high_quality) { + int index = 0; + if (high_quality) { + int multiplier = 1; + for (size_t c = 0; c < color.size(); c++) { + int quantized = ((kLargeCube - 1) * color[c] + (1 << (bit_depth - 1))) / + ((1 << bit_depth) - 1); + JXL_ASSERT((quantized % kLargeCube) == quantized); + index += quantized * multiplier; + multiplier *= kLargeCube; + } + return index + palette_size + kLargeCubeOffset; + } else { + int multiplier = 1; + for (size_t c = 0; c < color.size(); c++) { + int value = color[c]; + value -= 1 << (std::max(0, bit_depth - 3)); + value = std::max(0, value); + int quantized = ((kLargeCube - 1) * value + (1 << (bit_depth - 1))) / + ((1 << bit_depth) - 1); + JXL_ASSERT((quantized % kLargeCube) == quantized); + if (quantized > kSmallCube - 1) { + quantized = kSmallCube - 1; + } + index += quantized * multiplier; + multiplier *= kSmallCube; + } + return index + palette_size; + } +} + +} // namespace palette_internal + +int RoundInt(int value, int div) { // symmetric rounding around 0 + if (value < 0) return -RoundInt(-value, div); + return (value + div / 2) / div; +} + +struct PaletteIterationData { + static constexpr int kMaxDeltas = 128; + bool final_run = false; + std::vector<pixel_type> deltas[3]; + std::vector<double> delta_distances; + std::vector<pixel_type> frequent_deltas[3]; + + // Populates `frequent_deltas` with items from `deltas` based on frequencies + // and color distances. + void FindFrequentColorDeltas(int num_pixels, int bitdepth) { + using pixel_type_3d = std::array<pixel_type, 3>; + std::map<pixel_type_3d, double> delta_frequency_map; + pixel_type bucket_size = 3 << std::max(0, bitdepth - 8); + // Store frequency weighted by delta distance from quantized value. + for (size_t i = 0; i < deltas[0].size(); ++i) { + pixel_type_3d delta = { + {RoundInt(deltas[0][i], bucket_size), + RoundInt(deltas[1][i], bucket_size), + RoundInt(deltas[2][i], bucket_size)}}; // a basic form of clustering + if (delta[0] == 0 && delta[1] == 0 && delta[2] == 0) continue; + delta_frequency_map[delta] += sqrt(sqrt(delta_distances[i])); + } + + const float delta_distance_multiplier = 1.0f / num_pixels; + + // Weigh frequencies by magnitude and normalize. + for (auto &delta_frequency : delta_frequency_map) { + std::vector<pixel_type> current_delta = {delta_frequency.first[0], + delta_frequency.first[1], + delta_frequency.first[2]}; + float delta_distance = + sqrt(palette_internal::ColorDistance({0, 0, 0}, current_delta)) + 1; + delta_frequency.second *= delta_distance * delta_distance_multiplier; + } + + // Sort by weighted frequency. + using pixel_type_3d_frequency = std::pair<pixel_type_3d, double>; + std::vector<pixel_type_3d_frequency> sorted_delta_frequency_map( + delta_frequency_map.begin(), delta_frequency_map.end()); + std::sort( + sorted_delta_frequency_map.begin(), sorted_delta_frequency_map.end(), + [](const pixel_type_3d_frequency &a, const pixel_type_3d_frequency &b) { + return a.second > b.second; + }); + + // Store the top deltas. + for (auto &delta_frequency : sorted_delta_frequency_map) { + if (frequent_deltas[0].size() >= kMaxDeltas) break; + // Number obtained by optimizing on jyrki31 corpus: + if (delta_frequency.second < 17) break; + for (int c = 0; c < 3; ++c) { + frequent_deltas[c].push_back(delta_frequency.first[c] * bucket_size); + } + } + } +}; + +Status FwdPaletteIteration(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t &nb_colors, uint32_t &nb_deltas, + bool ordered, bool lossy, Predictor &predictor, + const weighted::Header &wp_header, + PaletteIterationData &palette_iteration_data) { + JXL_QUIET_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c)); + JXL_ASSERT(begin_c >= input.nb_meta_channels); + uint32_t nb = end_c - begin_c + 1; + + size_t w = input.channel[begin_c].w; + size_t h = input.channel[begin_c].h; + + if (!lossy && nb == 1) { + // Channel palette special case + if (nb_colors == 0) return false; + std::vector<pixel_type> lookup; + pixel_type minval, maxval; + compute_minmax(input.channel[begin_c], &minval, &maxval); + size_t lookup_table_size = + static_cast<int64_t>(maxval) - static_cast<int64_t>(minval) + 1; + if (lookup_table_size > palette_internal::kMaxPaletteLookupTableSize) { + // a lookup table would use too much memory, instead use a slower approach + // with std::set + std::set<pixel_type> chpalette; + pixel_type idx = 0; + for (size_t y = 0; y < h; y++) { + const pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + const bool new_color = chpalette.insert(p[x]).second; + if (new_color) { + idx++; + if (idx > (int)nb_colors) return false; + } + } + } + JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx); + Channel pch(idx, 1); + pch.hshift = -1; + pch.vshift = -1; + nb_colors = idx; + idx = 0; + pixel_type *JXL_RESTRICT p_palette = pch.Row(0); + for (pixel_type p : chpalette) { + p_palette[idx++] = p; + } + for (size_t y = 0; y < h; y++) { + pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + for (idx = 0; p[x] != p_palette[idx] && idx < (int)nb_colors; idx++) { + } + JXL_DASSERT(idx < (int)nb_colors); + p[x] = idx; + } + } + predictor = Predictor::Zero; + input.nb_meta_channels++; + input.channel.insert(input.channel.begin(), std::move(pch)); + + return true; + } + lookup.resize(lookup_table_size, 0); + pixel_type idx = 0; + for (size_t y = 0; y < h; y++) { + const pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + if (lookup[p[x] - minval] == 0) { + lookup[p[x] - minval] = 1; + idx++; + if (idx > (int)nb_colors) return false; + } + } + } + JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx); + Channel pch(idx, 1); + pch.hshift = -1; + pch.vshift = -1; + nb_colors = idx; + idx = 0; + pixel_type *JXL_RESTRICT p_palette = pch.Row(0); + for (size_t i = 0; i < lookup_table_size; i++) { + if (lookup[i]) { + p_palette[idx] = i + minval; + lookup[i] = idx; + idx++; + } + } + for (size_t y = 0; y < h; y++) { + pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) p[x] = lookup[p[x] - minval]; + } + predictor = Predictor::Zero; + input.nb_meta_channels++; + input.channel.insert(input.channel.begin(), std::move(pch)); + return true; + } + + Image quantized_input; + if (lossy) { + quantized_input = Image(w, h, input.bitdepth, nb); + for (size_t c = 0; c < nb; c++) { + CopyImageTo(input.channel[begin_c + c].plane, + &quantized_input.channel[c].plane); + } + } + + JXL_DEBUG_V( + 7, "Trying to represent channels %i-%i using at most a %i-color palette.", + begin_c, end_c, nb_colors); + nb_deltas = 0; + bool delta_used = false; + std::set<std::vector<pixel_type>> + candidate_palette; // ordered lexicographically + std::vector<std::vector<pixel_type>> candidate_palette_imageorder; + std::vector<pixel_type> color(nb); + std::vector<float> color_with_error(nb); + std::vector<const pixel_type *> p_in(nb); + + if (lossy) { + palette_iteration_data.FindFrequentColorDeltas(w * h, input.bitdepth); + nb_deltas = palette_iteration_data.frequent_deltas[0].size(); + + // Count color frequency for colors that make a cross. + std::map<std::vector<pixel_type>, size_t> color_freq_map; + for (size_t y = 1; y + 1 < h; y++) { + for (uint32_t c = 0; c < nb; c++) { + p_in[c] = input.channel[begin_c + c].Row(y); + } + for (size_t x = 1; x + 1 < w; x++) { + for (uint32_t c = 0; c < nb; c++) { + color[c] = p_in[c][x]; + } + int offsets[4][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}}; + bool makes_cross = true; + for (int i = 0; i < 4 && makes_cross; ++i) { + int dx = offsets[i][0]; + int dy = offsets[i][1]; + for (uint32_t c = 0; c < nb && makes_cross; c++) { + if (input.channel[begin_c + c].Row(y + dy)[x + dx] != color[c]) { + makes_cross = false; + } + } + } + if (makes_cross) color_freq_map[color] += 1; + } + } + // Add colors satisfying frequency condition to the palette. + constexpr float kImageFraction = 0.01f; + size_t color_frequency_lower_bound = 5 + input.h * input.w * kImageFraction; + for (const auto &color_freq : color_freq_map) { + if (color_freq.second > color_frequency_lower_bound) { + candidate_palette.insert(color_freq.first); + candidate_palette_imageorder.push_back(color_freq.first); + } + } + } + + for (size_t y = 0; y < h; y++) { + for (uint32_t c = 0; c < nb; c++) { + p_in[c] = input.channel[begin_c + c].Row(y); + } + for (size_t x = 0; x < w; x++) { + if (lossy && candidate_palette.size() >= nb_colors) break; + for (uint32_t c = 0; c < nb; c++) { + color[c] = p_in[c][x]; + } + const bool new_color = candidate_palette.insert(color).second; + if (new_color) { + candidate_palette_imageorder.push_back(color); + } + if (candidate_palette.size() > nb_colors) { + return false; // too many colors + } + } + } + + nb_colors = nb_deltas + candidate_palette.size(); + JXL_DEBUG_V(6, "Channels %i-%i can be represented using a %i-color palette.", + begin_c, end_c, nb_colors); + + Channel pch(nb_colors, nb); + pch.hshift = -1; + pch.vshift = -1; + pixel_type *JXL_RESTRICT p_palette = pch.Row(0); + intptr_t onerow = pch.plane.PixelsPerRow(); + intptr_t onerow_image = input.channel[begin_c].plane.PixelsPerRow(); + const int bit_depth = std::min(input.bitdepth, 24); + + if (lossy) { + for (uint32_t i = 0; i < nb_deltas; i++) { + for (size_t c = 0; c < 3; c++) { + p_palette[c * onerow + i] = + palette_iteration_data.frequent_deltas[c][i]; + } + } + } + + int x = 0; + if (ordered) { + JXL_DEBUG_V(7, "Palette of %i colors, using lexicographic order", + nb_colors); + for (auto pcol : candidate_palette) { + JXL_DEBUG_V(9, " Color %i : ", x); + for (size_t i = 0; i < nb; i++) { + p_palette[nb_deltas + i * onerow + x] = pcol[i]; + } + for (size_t i = 0; i < nb; i++) { + JXL_DEBUG_V(9, "%i ", pcol[i]); + } + x++; + } + } else { + JXL_DEBUG_V(7, "Palette of %i colors, using image order", nb_colors); + for (auto pcol : candidate_palette_imageorder) { + JXL_DEBUG_V(9, " Color %i : ", x); + for (size_t i = 0; i < nb; i++) + p_palette[nb_deltas + i * onerow + x] = pcol[i]; + for (size_t i = 0; i < nb; i++) JXL_DEBUG_V(9, "%i ", pcol[i]); + x++; + } + } + std::vector<weighted::State> wp_states; + for (size_t c = 0; c < nb; c++) { + wp_states.emplace_back(wp_header, w, h); + } + std::vector<pixel_type *> p_quant(nb); + // Three rows of error for dithering: y to y + 2. + // Each row has two pixels of padding in the ends, which is + // beneficial for both precision and encoding speed. + std::vector<std::vector<float>> error_row[3]; + if (lossy) { + for (int i = 0; i < 3; ++i) { + error_row[i].resize(nb); + for (size_t c = 0; c < nb; ++c) { + error_row[i][c].resize(w + 4); + } + } + } + for (size_t y = 0; y < h; y++) { + for (size_t c = 0; c < nb; c++) { + p_in[c] = input.channel[begin_c + c].Row(y); + if (lossy) p_quant[c] = quantized_input.channel[c].Row(y); + } + pixel_type *JXL_RESTRICT p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + int index; + if (!lossy) { + for (size_t c = 0; c < nb; c++) color[c] = p_in[c][x]; + // Exact search. + for (index = 0; static_cast<uint32_t>(index) < nb_colors; index++) { + bool found = true; + for (size_t c = 0; c < nb; c++) { + if (color[c] != p_palette[c * onerow + index]) { + found = false; + break; + } + } + if (found) break; + } + if (index < static_cast<int>(nb_deltas)) { + delta_used = true; + } + } else { + int best_index = 0; + bool best_is_delta = false; + float best_distance = std::numeric_limits<float>::infinity(); + std::vector<pixel_type> best_val(nb, 0); + std::vector<pixel_type> ideal_residual(nb, 0); + std::vector<pixel_type> quantized_val(nb); + std::vector<pixel_type> predictions(nb); + static const double kDiffusionMultiplier[] = {0.55, 0.75}; + for (int diffusion_index = 0; diffusion_index < 2; ++diffusion_index) { + for (size_t c = 0; c < nb; c++) { + color_with_error[c] = + p_in[c][x] + palette_iteration_data.final_run * + kDiffusionMultiplier[diffusion_index] * + error_row[0][c][x + 2]; + color[c] = Clamp1(lroundf(color_with_error[c]), 0l, + (1l << input.bitdepth) - 1); + } + + for (size_t c = 0; c < nb; ++c) { + predictions[c] = PredictNoTreeWP(w, p_quant[c] + x, onerow_image, x, + y, predictor, &wp_states[c]) + .guess; + } + const auto TryIndex = [&](const int index) { + for (size_t c = 0; c < nb; c++) { + quantized_val[c] = palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/nb_colors, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + if (index < static_cast<int>(nb_deltas)) { + quantized_val[c] += predictions[c]; + } + } + const float color_distance = + 32.0 / (1LL << std::max(0, 2 * (bit_depth - 8))) * + palette_internal::ColorDistance(color_with_error, + quantized_val); + float index_penalty = 0; + if (index == -1) { + index_penalty = -124; + } else if (index < 0) { + index_penalty = -2 * index; + } else if (index < static_cast<int>(nb_deltas)) { + index_penalty = 250; + } else if (index < static_cast<int>(nb_colors)) { + index_penalty = 150; + } else if (index < static_cast<int>(nb_colors) + + palette_internal::kLargeCubeOffset) { + index_penalty = 70; + } else { + index_penalty = 256; + } + const float distance = color_distance + index_penalty; + if (distance < best_distance) { + best_distance = distance; + best_index = index; + best_is_delta = index < static_cast<int>(nb_deltas); + best_val.swap(quantized_val); + for (size_t c = 0; c < nb; ++c) { + ideal_residual[c] = color_with_error[c] - predictions[c]; + } + } + }; + for (index = palette_internal::kMinImplicitPaletteIndex; + index < static_cast<int32_t>(nb_colors); index++) { + TryIndex(index); + } + TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex( + color, nb_colors, bit_depth, + /*high_quality=*/false)); + if (palette_internal::kEncodeToHighQualityImplicitPalette) { + TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex( + color, nb_colors, bit_depth, + /*high_quality=*/true)); + } + } + index = best_index; + delta_used |= best_is_delta; + if (!palette_iteration_data.final_run) { + for (size_t c = 0; c < 3; ++c) { + palette_iteration_data.deltas[c].push_back(ideal_residual[c]); + } + palette_iteration_data.delta_distances.push_back(best_distance); + } + + for (size_t c = 0; c < nb; ++c) { + wp_states[c].UpdateErrors(best_val[c], x, y, w); + p_quant[c][x] = best_val[c]; + } + float len_error = 0; + for (size_t c = 0; c < nb; ++c) { + float local_error = color_with_error[c] - best_val[c]; + len_error += local_error * local_error; + } + len_error = sqrt(len_error); + float modulate = 1.0; + int len_limit = 38 << std::max(0, bit_depth - 8); + if (len_error > len_limit) { + modulate *= len_limit / len_error; + } + for (size_t c = 0; c < nb; ++c) { + float total_error = (color_with_error[c] - best_val[c]); + + // If the neighboring pixels have some error in the opposite + // direction of total_error, cancel some or all of it out before + // spreading among them. + constexpr int offsets[12][2] = {{1, 2}, {0, 3}, {0, 4}, {1, 1}, + {1, 3}, {2, 2}, {1, 0}, {1, 4}, + {2, 1}, {2, 3}, {2, 0}, {2, 4}}; + float total_available = 0; + for (int i = 0; i < 11; ++i) { + const int row = offsets[i][0]; + const int col = offsets[i][1]; + if (std::signbit(error_row[row][c][x + col]) != + std::signbit(total_error)) { + total_available += error_row[row][c][x + col]; + } + } + float weight = + std::abs(total_error) / (std::abs(total_available) + 1e-3); + weight = std::min(weight, 1.0f); + for (int i = 0; i < 11; ++i) { + const int row = offsets[i][0]; + const int col = offsets[i][1]; + if (std::signbit(error_row[row][c][x + col]) != + std::signbit(total_error)) { + total_error += weight * error_row[row][c][x + col]; + error_row[row][c][x + col] *= (1 - weight); + } + } + total_error *= modulate; + const float remaining_error = (1.0f / 14.) * total_error; + error_row[0][c][x + 3] += 2 * remaining_error; + error_row[0][c][x + 4] += remaining_error; + error_row[1][c][x + 0] += remaining_error; + for (int i = 0; i < 5; ++i) { + error_row[1][c][x + i] += remaining_error; + error_row[2][c][x + i] += remaining_error; + } + } + } + if (palette_iteration_data.final_run) p[x] = index; + } + if (lossy) { + for (size_t c = 0; c < nb; ++c) { + error_row[0][c].swap(error_row[1][c]); + error_row[1][c].swap(error_row[2][c]); + std::fill(error_row[2][c].begin(), error_row[2][c].end(), 0.f); + } + } + } + if (!delta_used) { + predictor = Predictor::Zero; + } + if (palette_iteration_data.final_run) { + input.nb_meta_channels++; + input.channel.erase(input.channel.begin() + begin_c + 1, + input.channel.begin() + end_c + 1); + input.channel.insert(input.channel.begin(), std::move(pch)); + } + nb_colors -= nb_deltas; + return true; +} + +Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t &nb_colors, uint32_t &nb_deltas, bool ordered, + bool lossy, Predictor &predictor, + const weighted::Header &wp_header) { + PaletteIterationData palette_iteration_data; + uint32_t nb_colors_orig = nb_colors; + uint32_t nb_deltas_orig = nb_deltas; + // preprocessing pass in case of lossy palette + if (lossy && input.bitdepth >= 8) { + JXL_RETURN_IF_ERROR(FwdPaletteIteration( + input, begin_c, end_c, nb_colors_orig, nb_deltas_orig, ordered, lossy, + predictor, wp_header, palette_iteration_data)); + } + palette_iteration_data.final_run = true; + return FwdPaletteIteration(input, begin_c, end_c, nb_colors, nb_deltas, + ordered, lossy, predictor, wp_header, + palette_iteration_data); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h new file mode 100644 index 0000000000..0f3d66825b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_ + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t &nb_colors, uint32_t &nb_deltas, bool ordered, + bool lossy, Predictor &predictor, + const weighted::Header &wp_header); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc new file mode 100644 index 0000000000..050563a3c2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc @@ -0,0 +1,73 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_rct.h" + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels + +namespace jxl { + +Status FwdRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) { + JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, begin_c + 2)); + if (rct_type == 0) { // noop + return false; + } + // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR + int permutation = rct_type / 7; + // 0-5 values have the low bit corresponding to Third and the high bits + // corresponding to Second. 6 corresponds to YCoCg. + // + // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird + // + // Third: 0=nop, 1=SubtractFirst + int custom = rct_type % 7; + size_t m = begin_c; + size_t w = input.channel[m + 0].w; + size_t h = input.channel[m + 0].h; + int second = (custom % 7) >> 1; + int third = (custom % 7) & 1; + const auto do_rct = [&](const int y, const int thread) { + const pixel_type* in0 = input.channel[m + (permutation % 3)].Row(y); + const pixel_type* in1 = + input.channel[m + ((permutation + 1 + permutation / 3) % 3)].Row(y); + const pixel_type* in2 = + input.channel[m + ((permutation + 2 - permutation / 3) % 3)].Row(y); + pixel_type* out0 = input.channel[m].Row(y); + pixel_type* out1 = input.channel[m + 1].Row(y); + pixel_type* out2 = input.channel[m + 2].Row(y); + if (custom == 6) { + for (size_t x = 0; x < w; x++) { + pixel_type R = in0[x]; + pixel_type G = in1[x]; + pixel_type B = in2[x]; + out1[x] = R - B; + pixel_type tmp = B + (out1[x] >> 1); + out2[x] = G - tmp; + out0[x] = tmp + (out2[x] >> 1); + } + } else { + for (size_t x = 0; x < w; x++) { + pixel_type First = in0[x]; + pixel_type Second = in1[x]; + pixel_type Third = in2[x]; + if (second == 1) { + Second = Second - First; + } else if (second == 2) { + Second = Second - ((First + Third) >> 1); + } + if (third) Third = Third - First; + out0[x] = First; + out1[x] = Second; + out2[x] = Third; + } + } + }; + return RunOnPool(pool, 0, h, ThreadPool::NoInit, do_rct, "FwdRCT"); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h new file mode 100644 index 0000000000..cb5a193c8d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h @@ -0,0 +1,17 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_ + +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +Status FwdRCT(Image &input, size_t begin_c, size_t rct_type, ThreadPool *pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc new file mode 100644 index 0000000000..dfd90cde68 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc @@ -0,0 +1,141 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_squeeze.h" + +#include <stdlib.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/squeeze.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +void FwdHSqueeze(Image &input, int c, int rc) { + const Channel &chin = input.channel[c]; + + JXL_DEBUG_V(4, "Doing horizontal squeeze of channel %i to new channel %i", c, + rc); + + Channel chout((chin.w + 1) / 2, chin.h, chin.hshift + 1, chin.vshift); + Channel chout_residual(chin.w - chout.w, chout.h, chin.hshift + 1, + chin.vshift); + + for (size_t y = 0; y < chout.h; y++) { + const pixel_type *JXL_RESTRICT p_in = chin.Row(y); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + pixel_type *JXL_RESTRICT p_res = chout_residual.Row(y); + for (size_t x = 0; x < chout_residual.w; x++) { + pixel_type A = p_in[x * 2]; + pixel_type B = p_in[x * 2 + 1]; + pixel_type avg = (A + B + (A > B)) >> 1; + p_out[x] = avg; + + pixel_type diff = A - B; + + pixel_type next_avg = avg; + if (x + 1 < chout_residual.w) { + next_avg = (p_in[x * 2 + 2] + p_in[x * 2 + 3] + + (p_in[x * 2 + 2] > p_in[x * 2 + 3])) >> + 1; // which will be chout.value(y,x+1) + } else if (chin.w & 1) + next_avg = p_in[x * 2 + 2]; + pixel_type left = (x > 0 ? p_in[x * 2 - 1] : avg); + pixel_type tendency = SmoothTendency(left, avg, next_avg); + + p_res[x] = diff - tendency; + } + if (chin.w & 1) { + int x = chout.w - 1; + p_out[x] = p_in[x * 2]; + } + } + input.channel[c] = std::move(chout); + input.channel.insert(input.channel.begin() + rc, std::move(chout_residual)); +} + +void FwdVSqueeze(Image &input, int c, int rc) { + const Channel &chin = input.channel[c]; + + JXL_DEBUG_V(4, "Doing vertical squeeze of channel %i to new channel %i", c, + rc); + + Channel chout(chin.w, (chin.h + 1) / 2, chin.hshift, chin.vshift + 1); + Channel chout_residual(chin.w, chin.h - chout.h, chin.hshift, + chin.vshift + 1); + intptr_t onerow_in = chin.plane.PixelsPerRow(); + for (size_t y = 0; y < chout_residual.h; y++) { + const pixel_type *JXL_RESTRICT p_in = chin.Row(y * 2); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + pixel_type *JXL_RESTRICT p_res = chout_residual.Row(y); + for (size_t x = 0; x < chout.w; x++) { + pixel_type A = p_in[x]; + pixel_type B = p_in[x + onerow_in]; + pixel_type avg = (A + B + (A > B)) >> 1; + p_out[x] = avg; + + pixel_type diff = A - B; + + pixel_type next_avg = avg; + if (y + 1 < chout_residual.h) { + next_avg = (p_in[x + 2 * onerow_in] + p_in[x + 3 * onerow_in] + + (p_in[x + 2 * onerow_in] > p_in[x + 3 * onerow_in])) >> + 1; // which will be chout.value(y+1,x) + } else if (chin.h & 1) { + next_avg = p_in[x + 2 * onerow_in]; + } + pixel_type top = + (y > 0 ? p_in[static_cast<ssize_t>(x) - onerow_in] : avg); + pixel_type tendency = SmoothTendency(top, avg, next_avg); + + p_res[x] = diff - tendency; + } + } + if (chin.h & 1) { + size_t y = chout.h - 1; + const pixel_type *p_in = chin.Row(y * 2); + pixel_type *p_out = chout.Row(y); + for (size_t x = 0; x < chout.w; x++) { + p_out[x] = p_in[x]; + } + } + input.channel[c] = std::move(chout); + input.channel.insert(input.channel.begin() + rc, std::move(chout_residual)); +} + +Status FwdSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool) { + if (parameters.empty()) { + DefaultSqueezeParameters(¶meters, input); + } + // if nothing to do, don't do squeeze + if (parameters.empty()) return false; + for (size_t i = 0; i < parameters.size(); i++) { + JXL_RETURN_IF_ERROR( + CheckMetaSqueezeParams(parameters[i], input.channel.size())); + bool horizontal = parameters[i].horizontal; + bool in_place = parameters[i].in_place; + uint32_t beginc = parameters[i].begin_c; + uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1; + uint32_t offset; + if (in_place) { + offset = endc + 1; + } else { + offset = input.channel.size(); + } + for (uint32_t c = beginc; c <= endc; c++) { + if (horizontal) { + FwdHSqueeze(input, c, offset + c - beginc); + } else { + FwdVSqueeze(input, c, offset + c - beginc); + } + } + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h new file mode 100644 index 0000000000..39b001017b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_ + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +Status FwdSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc new file mode 100644 index 0000000000..bdaaf9f87e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_transform.h" + +#include "lib/jxl/modular/transform/enc_palette.h" +#include "lib/jxl/modular/transform/enc_rct.h" +#include "lib/jxl/modular/transform/enc_squeeze.h" + +namespace jxl { + +Status TransformForward(Transform &t, Image &input, + const weighted::Header &wp_header, ThreadPool *pool) { + switch (t.id) { + case TransformId::kRCT: + return FwdRCT(input, t.begin_c, t.rct_type, pool); + case TransformId::kSqueeze: + return FwdSqueeze(input, t.squeezes, pool); + case TransformId::kPalette: + return FwdPalette(input, t.begin_c, t.begin_c + t.num_c - 1, t.nb_colors, + t.nb_deltas, t.ordered_palette, t.lossy_palette, + t.predictor, wp_header); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast<unsigned int>(t.id)); + } +} + +void compute_minmax(const Channel &ch, pixel_type *min, pixel_type *max) { + pixel_type realmin = std::numeric_limits<pixel_type>::max(); + pixel_type realmax = std::numeric_limits<pixel_type>::min(); + for (size_t y = 0; y < ch.h; y++) { + const pixel_type *JXL_RESTRICT p = ch.Row(y); + for (size_t x = 0; x < ch.w; x++) { + if (p[x] < realmin) realmin = p[x]; + if (p[x] > realmax) realmax = p[x]; + } + } + + if (min) *min = realmin; + if (max) *max = realmax; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h new file mode 100644 index 0000000000..07659e1b0a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_ + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +Status TransformForward(Transform &t, Image &input, + const weighted::Header &wp_header, ThreadPool *pool); + +void compute_minmax(const Channel &ch, pixel_type *min, pixel_type *max); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc new file mode 100644 index 0000000000..46129f19f0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc @@ -0,0 +1,176 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/palette.h" + +namespace jxl { + +Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors, + uint32_t nb_deltas, Predictor predictor, + const weighted::Header &wp_header, ThreadPool *pool) { + if (input.nb_meta_channels < 1) { + return JXL_FAILURE("Error: Palette transform without palette."); + } + std::atomic<int> num_errors{0}; + int nb = input.channel[0].h; + uint32_t c0 = begin_c + 1; + if (c0 >= input.channel.size()) { + return JXL_FAILURE("Channel is out of range."); + } + size_t w = input.channel[c0].w; + size_t h = input.channel[c0].h; + if (nb < 1) return JXL_FAILURE("Corrupted transforms"); + for (int i = 1; i < nb; i++) { + input.channel.insert( + input.channel.begin() + c0 + 1, + Channel(w, h, input.channel[c0].hshift, input.channel[c0].vshift)); + } + const Channel &palette = input.channel[0]; + const pixel_type *JXL_RESTRICT p_palette = input.channel[0].Row(0); + intptr_t onerow = input.channel[0].plane.PixelsPerRow(); + intptr_t onerow_image = input.channel[c0].plane.PixelsPerRow(); + const int bit_depth = std::min(input.bitdepth, 24); + + if (w == 0) { + // Nothing to do. + // Avoid touching "empty" channels with non-zero height. + } else if (nb_deltas == 0 && predictor == Predictor::Zero) { + if (nb == 1) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, h, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + pixel_type *p = input.channel[c0].Row(y); + for (size_t x = 0; x < w; x++) { + const int index = Clamp1<int>(p[x], 0, (pixel_type)palette.w - 1); + p[x] = palette_internal::GetPaletteValue( + p_palette, index, /*c=*/0, + /*palette_size=*/palette.w, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + } + }, + "UndoChannelPalette")); + } else { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, h, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + std::vector<pixel_type *> p_out(nb); + const pixel_type *p_index = input.channel[c0].Row(y); + for (int c = 0; c < nb; c++) + p_out[c] = input.channel[c0 + c].Row(y); + for (size_t x = 0; x < w; x++) { + const int index = p_index[x]; + for (int c = 0; c < nb; c++) { + p_out[c][x] = palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/palette.w, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + } + } + }, + "UndoPalette")); + } + } else { + // Parallelized per channel. + ImageI indices = CopyImage(input.channel[c0].plane); + if (predictor == Predictor::Weighted) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, nb, ThreadPool::NoInit, + [&](const uint32_t c, size_t /* thread */) { + Channel &channel = input.channel[c0 + c]; + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + const pixel_type *JXL_RESTRICT idx = indices.Row(y); + for (size_t x = 0; x < channel.w; x++) { + int index = idx[x]; + pixel_type_w val = 0; + const pixel_type palette_entry = + palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/palette.w, /*onerow=*/onerow, + /*bit_depth=*/bit_depth); + if (index < static_cast<int32_t>(nb_deltas)) { + PredictionResult pred = + PredictNoTreeWP(channel.w, p + x, onerow_image, x, y, + predictor, &wp_state); + val = pred.guess + palette_entry; + } else { + val = palette_entry; + } + p[x] = val; + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + }, + "UndoDeltaPaletteWP")); + } else { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, nb, ThreadPool::NoInit, + [&](const uint32_t c, size_t /* thread */) { + Channel &channel = input.channel[c0 + c]; + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + const pixel_type *JXL_RESTRICT idx = indices.Row(y); + for (size_t x = 0; x < channel.w; x++) { + int index = idx[x]; + pixel_type_w val = 0; + const pixel_type palette_entry = + palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/palette.w, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + if (index < static_cast<int32_t>(nb_deltas)) { + PredictionResult pred = PredictNoTreeNoWP( + channel.w, p + x, onerow_image, x, y, predictor); + val = pred.guess + palette_entry; + } else { + val = palette_entry; + } + p[x] = val; + } + } + }, + "UndoDeltaPaletteNoWP")); + } + } + if (c0 >= input.nb_meta_channels) { + // Palette was done on normal channels + input.nb_meta_channels--; + } else { + // Palette was done on metachannels + JXL_ASSERT(static_cast<int>(input.nb_meta_channels) >= 2 - nb); + input.nb_meta_channels -= 2 - nb; + JXL_ASSERT(begin_c + nb - 1 < input.nb_meta_channels); + } + input.channel.erase(input.channel.begin(), input.channel.begin() + 1); + return num_errors.load(std::memory_order_relaxed) == 0; +} + +Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t nb_colors, uint32_t nb_deltas, bool lossy) { + JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c)); + + size_t nb = end_c - begin_c + 1; + if (begin_c >= input.nb_meta_channels) { + // Palette was done on normal channels + input.nb_meta_channels++; + } else { + // Palette was done on metachannels + JXL_ASSERT(end_c < input.nb_meta_channels); + // we remove nb-1 metachannels and add one + input.nb_meta_channels += 2 - nb; + } + input.channel.erase(input.channel.begin() + begin_c + 1, + input.channel.begin() + end_c + 1); + Channel pch(nb_colors + nb_deltas, nb); + pch.hshift = -1; + pch.vshift = -1; + input.channel.insert(input.channel.begin(), std::move(pch)); + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h new file mode 100644 index 0000000000..cc0f67960b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h @@ -0,0 +1,129 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_ + +#include <atomic> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels + +namespace jxl { + +namespace palette_internal { + +static constexpr int kMaxPaletteLookupTableSize = 1 << 16; + +static constexpr int kRgbChannels = 3; + +// 5x5x5 color cube for the larger cube. +static constexpr int kLargeCube = 5; + +// Smaller interleaved color cube to fill the holes of the larger cube. +static constexpr int kSmallCube = 4; +static constexpr int kSmallCubeBits = 2; +// kSmallCube ** 3 +static constexpr int kLargeCubeOffset = kSmallCube * kSmallCube * kSmallCube; + +static inline pixel_type Scale(uint64_t value, uint64_t bit_depth, + uint64_t denom) { + // return (value * ((static_cast<pixel_type_w>(1) << bit_depth) - 1)) / denom; + // We only call this function with kSmallCube or kLargeCube - 1 as denom, + // allowing us to avoid a division here. + JXL_ASSERT(denom == 4); + return (value * ((static_cast<uint64_t>(1) << bit_depth) - 1)) >> 2; +} + +// The purpose of this function is solely to extend the interpretation of +// palette indices to implicit values. If index < nb_deltas, indicating that the +// result is a delta palette entry, it is the responsibility of the caller to +// treat it as such. +static JXL_MAYBE_UNUSED pixel_type +GetPaletteValue(const pixel_type *const palette, int index, const size_t c, + const int palette_size, const int onerow, const int bit_depth) { + if (index < 0) { + static constexpr std::array<std::array<pixel_type, 3>, 72> kDeltaPalette = { + { + {{0, 0, 0}}, {{4, 4, 4}}, {{11, 0, 0}}, + {{0, 0, -13}}, {{0, -12, 0}}, {{-10, -10, -10}}, + {{-18, -18, -18}}, {{-27, -27, -27}}, {{-18, -18, 0}}, + {{0, 0, -32}}, {{-32, 0, 0}}, {{-37, -37, -37}}, + {{0, -32, -32}}, {{24, 24, 45}}, {{50, 50, 50}}, + {{-45, -24, -24}}, {{-24, -45, -45}}, {{0, -24, -24}}, + {{-34, -34, 0}}, {{-24, 0, -24}}, {{-45, -45, -24}}, + {{64, 64, 64}}, {{-32, 0, -32}}, {{0, -32, 0}}, + {{-32, 0, 32}}, {{-24, -45, -24}}, {{45, 24, 45}}, + {{24, -24, -45}}, {{-45, -24, 24}}, {{80, 80, 80}}, + {{64, 0, 0}}, {{0, 0, -64}}, {{0, -64, -64}}, + {{-24, -24, 45}}, {{96, 96, 96}}, {{64, 64, 0}}, + {{45, -24, -24}}, {{34, -34, 0}}, {{112, 112, 112}}, + {{24, -45, -45}}, {{45, 45, -24}}, {{0, -32, 32}}, + {{24, -24, 45}}, {{0, 96, 96}}, {{45, -24, 24}}, + {{24, -45, -24}}, {{-24, -45, 24}}, {{0, -64, 0}}, + {{96, 0, 0}}, {{128, 128, 128}}, {{64, 0, 64}}, + {{144, 144, 144}}, {{96, 96, 0}}, {{-36, -36, 36}}, + {{45, -24, -45}}, {{45, -45, -24}}, {{0, 0, -96}}, + {{0, 128, 128}}, {{0, 96, 0}}, {{45, 24, -45}}, + {{-128, 0, 0}}, {{24, -45, 24}}, {{-45, 24, -45}}, + {{64, 0, -64}}, {{64, -64, -64}}, {{96, 0, 96}}, + {{45, -45, 24}}, {{24, 45, -45}}, {{64, 64, -64}}, + {{128, 128, 0}}, {{0, 0, -128}}, {{-24, 45, -45}}, + }}; + if (c >= kRgbChannels) { + return 0; + } + // Do not open the brackets, otherwise INT32_MIN negation could overflow. + index = -(index + 1); + index %= 1 + 2 * (kDeltaPalette.size() - 1); + static constexpr int kMultiplier[] = {-1, 1}; + pixel_type result = + kDeltaPalette[((index + 1) >> 1)][c] * kMultiplier[index & 1]; + if (bit_depth > 8) { + result *= static_cast<pixel_type>(1) << (bit_depth - 8); + } + return result; + } else if (palette_size <= index && index < palette_size + kLargeCubeOffset) { + if (c >= kRgbChannels) return 0; + index -= palette_size; + index >>= c * kSmallCubeBits; + return Scale(index % kSmallCube, bit_depth, kSmallCube) + + (1 << (std::max(0, bit_depth - 3))); + } else if (palette_size + kLargeCubeOffset <= index) { + if (c >= kRgbChannels) return 0; + index -= palette_size + kLargeCubeOffset; + // TODO(eustas): should we take care of ambiguity created by + // index >= kLargeCube ** 3 ? + switch (c) { + case 0: + break; + case 1: + index /= kLargeCube; + break; + case 2: + index /= kLargeCube * kLargeCube; + break; + } + return Scale(index % kLargeCube, bit_depth, kLargeCube - 1); + } + return palette[c * onerow + static_cast<size_t>(index)]; +} + +} // namespace palette_internal + +Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors, + uint32_t nb_deltas, Predictor predictor, + const weighted::Header &wp_header, ThreadPool *pool); + +Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t nb_colors, uint32_t nb_deltas, bool lossy); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc new file mode 100644 index 0000000000..f3002a5ac3 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc @@ -0,0 +1,153 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/rct.h" +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/rct.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; + +template <int transform_type> +void InvRCTRow(const pixel_type* in0, const pixel_type* in1, + const pixel_type* in2, pixel_type* out0, pixel_type* out1, + pixel_type* out2, size_t w) { + static_assert(transform_type >= 0 && transform_type < 7, + "Invalid transform type"); + int second = transform_type >> 1; + int third = transform_type & 1; + + size_t x = 0; + const HWY_FULL(pixel_type) d; + const size_t N = Lanes(d); + for (; x + N - 1 < w; x += N) { + if (transform_type == 6) { + auto Y = Load(d, in0 + x); + auto Co = Load(d, in1 + x); + auto Cg = Load(d, in2 + x); + Y = Sub(Y, ShiftRight<1>(Cg)); + auto G = Add(Cg, Y); + Y = Sub(Y, ShiftRight<1>(Co)); + auto R = Add(Y, Co); + Store(R, d, out0 + x); + Store(G, d, out1 + x); + Store(Y, d, out2 + x); + } else { + auto First = Load(d, in0 + x); + auto Second = Load(d, in1 + x); + auto Third = Load(d, in2 + x); + if (third) Third = Add(Third, First); + if (second == 1) { + Second = Add(Second, First); + } else if (second == 2) { + Second = Add(Second, ShiftRight<1>(Add(First, Third))); + } + Store(First, d, out0 + x); + Store(Second, d, out1 + x); + Store(Third, d, out2 + x); + } + } + for (; x < w; x++) { + if (transform_type == 6) { + pixel_type Y = in0[x]; + pixel_type Co = in1[x]; + pixel_type Cg = in2[x]; + pixel_type tmp = PixelAdd(Y, -(Cg >> 1)); + pixel_type G = PixelAdd(Cg, tmp); + pixel_type B = PixelAdd(tmp, -(Co >> 1)); + pixel_type R = PixelAdd(B, Co); + out0[x] = R; + out1[x] = G; + out2[x] = B; + } else { + pixel_type First = in0[x]; + pixel_type Second = in1[x]; + pixel_type Third = in2[x]; + if (third) Third = PixelAdd(Third, First); + if (second == 1) { + Second = PixelAdd(Second, First); + } else if (second == 2) { + Second = PixelAdd(Second, (PixelAdd(First, Third) >> 1)); + } + out0[x] = First; + out1[x] = Second; + out2[x] = Third; + } + } +} + +Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) { + JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, begin_c + 2)); + size_t m = begin_c; + Channel& c0 = input.channel[m + 0]; + size_t w = c0.w; + size_t h = c0.h; + if (rct_type == 0) { // noop + return true; + } + // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR + int permutation = rct_type / 7; + JXL_CHECK(permutation < 6); + // 0-5 values have the low bit corresponding to Third and the high bits + // corresponding to Second. 6 corresponds to YCoCg. + // + // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird + // + // Third: 0=nop, 1=SubtractFirst + int custom = rct_type % 7; + // Special case: permute-only. Swap channels around. + if (custom == 0) { + Channel ch0 = std::move(input.channel[m]); + Channel ch1 = std::move(input.channel[m + 1]); + Channel ch2 = std::move(input.channel[m + 2]); + input.channel[m + (permutation % 3)] = std::move(ch0); + input.channel[m + ((permutation + 1 + permutation / 3) % 3)] = + std::move(ch1); + input.channel[m + ((permutation + 2 - permutation / 3) % 3)] = + std::move(ch2); + return true; + } + constexpr decltype(&InvRCTRow<0>) inv_rct_row[] = { + InvRCTRow<0>, InvRCTRow<1>, InvRCTRow<2>, InvRCTRow<3>, + InvRCTRow<4>, InvRCTRow<5>, InvRCTRow<6>}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, h, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + const pixel_type* in0 = input.channel[m].Row(y); + const pixel_type* in1 = input.channel[m + 1].Row(y); + const pixel_type* in2 = input.channel[m + 2].Row(y); + pixel_type* out0 = input.channel[m + (permutation % 3)].Row(y); + pixel_type* out1 = + input.channel[m + ((permutation + 1 + permutation / 3) % 3)].Row(y); + pixel_type* out2 = + input.channel[m + ((permutation + 2 - permutation / 3) % 3)].Row(y); + inv_rct_row[custom](in0, in1, in2, out0, out1, out2, w); + }, + "InvRCT")); + return true; +} + +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(InvRCT); +Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(InvRCT)(input, begin_c, rct_type, pool); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/rct.h b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.h new file mode 100644 index 0000000000..aef65621d5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_RCT_H_ +#define LIB_JXL_MODULAR_TRANSFORM_RCT_H_ + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels + +namespace jxl { + +Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_RCT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc new file mode 100644 index 0000000000..8440d9e804 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc @@ -0,0 +1,478 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/squeeze.h" + +#include <stdlib.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/squeeze.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/simd_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::Gt; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenZeroElse; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::MulEven; +using hwy::HWY_NAMESPACE::Ne; +using hwy::HWY_NAMESPACE::Neg; +using hwy::HWY_NAMESPACE::OddEven; +using hwy::HWY_NAMESPACE::RebindToUnsigned; +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::Xor; + +#if HWY_TARGET != HWY_SCALAR + +JXL_INLINE void FastUnsqueeze(const pixel_type *JXL_RESTRICT p_residual, + const pixel_type *JXL_RESTRICT p_avg, + const pixel_type *JXL_RESTRICT p_navg, + const pixel_type *p_pout, + pixel_type *JXL_RESTRICT p_out, + pixel_type *p_nout) { + const HWY_CAPPED(pixel_type, 8) d; + const RebindToUnsigned<decltype(d)> du; + const size_t N = Lanes(d); + auto onethird = Set(d, 0x55555556); + for (size_t x = 0; x < 8; x += N) { + auto avg = Load(d, p_avg + x); + auto next_avg = Load(d, p_navg + x); + auto top = Load(d, p_pout + x); + // Equivalent to SmoothTendency(top,avg,next_avg), but without branches + auto Ba = Sub(top, avg); + auto an = Sub(avg, next_avg); + auto nonmono = Xor(Ba, an); + auto absBa = Abs(Ba); + auto absan = Abs(an); + auto absBn = Abs(Sub(top, next_avg)); + // Compute a3 = absBa / 3 + auto a3e = BitCast(d, ShiftRight<32>(MulEven(absBa, onethird))); + auto a3oi = MulEven(Reverse(d, absBa), onethird); + auto a3o = BitCast( + d, Reverse(hwy::HWY_NAMESPACE::Repartition<pixel_type_w, decltype(d)>(), + a3oi)); + auto a3 = OddEven(a3o, a3e); + a3 = Add(a3, Add(absBn, Set(d, 2))); + auto absdiff = ShiftRight<2>(a3); + auto skipdiff = Ne(Ba, Zero(d)); + skipdiff = And(skipdiff, Ne(an, Zero(d))); + skipdiff = And(skipdiff, Lt(nonmono, Zero(d))); + auto absBa2 = Add(ShiftLeft<1>(absBa), And(absdiff, Set(d, 1))); + absdiff = IfThenElse(Gt(absdiff, absBa2), + Add(ShiftLeft<1>(absBa), Set(d, 1)), absdiff); + auto absan2 = ShiftLeft<1>(absan); + absdiff = IfThenElse(Gt(Add(absdiff, And(absdiff, Set(d, 1))), absan2), + absan2, absdiff); + auto diff1 = IfThenElse(Lt(top, next_avg), Neg(absdiff), absdiff); + auto tendency = IfThenZeroElse(skipdiff, diff1); + + auto diff_minus_tendency = Load(d, p_residual + x); + auto diff = Add(diff_minus_tendency, tendency); + auto out = + Add(avg, ShiftRight<1>( + Add(diff, BitCast(d, ShiftRight<31>(BitCast(du, diff)))))); + Store(out, d, p_out + x); + Store(Sub(out, diff), d, p_nout + x); + } +} + +#endif + +Status InvHSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { + JXL_ASSERT(c < input.channel.size()); + JXL_ASSERT(rc < input.channel.size()); + Channel &chin = input.channel[c]; + const Channel &chin_residual = input.channel[rc]; + // These must be valid since we ran MetaApply already. + JXL_ASSERT(chin.w == DivCeil(chin.w + chin_residual.w, 2)); + JXL_ASSERT(chin.h == chin_residual.h); + + if (chin_residual.w == 0) { + // Short-circuit: output channel has same dimensions as input. + input.channel[c].hshift--; + return true; + } + + // Note: chin.w >= chin_residual.w and at most 1 different. + Channel chout(chin.w + chin_residual.w, chin.h, chin.hshift - 1, chin.vshift); + JXL_DEBUG_V(4, + "Undoing horizontal squeeze of channel %i using residuals in " + "channel %i (going from width %" PRIuS " to %" PRIuS ")", + c, rc, chin.w, chout.w); + + if (chin_residual.h == 0) { + // Short-circuit: channel with no pixels. + input.channel[c] = std::move(chout); + return true; + } + auto unsqueeze_row = [&](size_t y, size_t x0) { + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y); + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + for (size_t x = x0; x < chin_residual.w; x++) { + pixel_type_w diff_minus_tendency = p_residual[x]; + pixel_type_w avg = p_avg[x]; + pixel_type_w next_avg = (x + 1 < chin.w ? p_avg[x + 1] : avg); + pixel_type_w left = (x ? p_out[(x << 1) - 1] : avg); + pixel_type_w tendency = SmoothTendency(left, avg, next_avg); + pixel_type_w diff = diff_minus_tendency + tendency; + pixel_type_w A = avg + (diff / 2); + p_out[(x << 1)] = A; + pixel_type_w B = A - diff; + p_out[(x << 1) + 1] = B; + } + if (chout.w & 1) p_out[chout.w - 1] = p_avg[chin.w - 1]; + }; + + // somewhat complicated trickery just to be able to SIMD this. + // Horizontal unsqueeze has horizontal data dependencies, so we do + // 8 rows at a time and treat it as a vertical unsqueeze of a + // transposed 8x8 block (or 9x8 for one input). + static constexpr const size_t kRowsPerThread = 8; + const auto unsqueeze_span = [&](const uint32_t task, size_t /* thread */) { + const size_t y0 = task * kRowsPerThread; + const size_t rows = std::min(kRowsPerThread, chin.h - y0); + size_t x = 0; + +#if HWY_TARGET != HWY_SCALAR + intptr_t onerow_in = chin.plane.PixelsPerRow(); + intptr_t onerow_inr = chin_residual.plane.PixelsPerRow(); + intptr_t onerow_out = chout.plane.PixelsPerRow(); + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y0); + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y0); + pixel_type *JXL_RESTRICT p_out = chout.Row(y0); + HWY_ALIGN pixel_type b_p_avg[9 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_residual[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_even[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_odd[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_evenT[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_oddT[8 * kRowsPerThread]; + const HWY_CAPPED(pixel_type, 8) d; + const size_t N = Lanes(d); + if (chin_residual.w > 16 && rows == kRowsPerThread) { + for (; x < chin_residual.w - 9; x += 8) { + Transpose8x8Block(p_residual + x, b_p_residual, onerow_inr); + Transpose8x8Block(p_avg + x, b_p_avg, onerow_in); + for (size_t y = 0; y < kRowsPerThread; y++) { + b_p_avg[8 * 8 + y] = p_avg[x + 8 + onerow_in * y]; + } + for (size_t i = 0; i < 8; i++) { + FastUnsqueeze( + b_p_residual + 8 * i, b_p_avg + 8 * i, b_p_avg + 8 * (i + 1), + (x + i ? b_p_out_odd + 8 * ((x + i - 1) & 7) : b_p_avg + 8 * i), + b_p_out_even + 8 * i, b_p_out_odd + 8 * i); + } + + Transpose8x8Block(b_p_out_even, b_p_out_evenT, 8); + Transpose8x8Block(b_p_out_odd, b_p_out_oddT, 8); + for (size_t y = 0; y < kRowsPerThread; y++) { + for (size_t i = 0; i < kRowsPerThread; i += N) { + auto even = Load(d, b_p_out_evenT + 8 * y + i); + auto odd = Load(d, b_p_out_oddT + 8 * y + i); + StoreInterleaved(d, even, odd, + p_out + ((x + i) << 1) + onerow_out * y); + } + } + } + } +#endif + for (size_t y = 0; y < rows; y++) { + unsqueeze_row(y0 + y, x); + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.h, kRowsPerThread), + ThreadPool::NoInit, unsqueeze_span, + "InvHorizontalSqueeze")); + input.channel[c] = std::move(chout); + return true; +} + +Status InvVSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { + JXL_ASSERT(c < input.channel.size()); + JXL_ASSERT(rc < input.channel.size()); + const Channel &chin = input.channel[c]; + const Channel &chin_residual = input.channel[rc]; + // These must be valid since we ran MetaApply already. + JXL_ASSERT(chin.h == DivCeil(chin.h + chin_residual.h, 2)); + JXL_ASSERT(chin.w == chin_residual.w); + + if (chin_residual.h == 0) { + // Short-circuit: output channel has same dimensions as input. + input.channel[c].vshift--; + return true; + } + + // Note: chin.h >= chin_residual.h and at most 1 different. + Channel chout(chin.w, chin.h + chin_residual.h, chin.hshift, chin.vshift - 1); + JXL_DEBUG_V( + 4, + "Undoing vertical squeeze of channel %i using residuals in channel " + "%i (going from height %" PRIuS " to %" PRIuS ")", + c, rc, chin.h, chout.h); + + if (chin_residual.w == 0) { + // Short-circuit: channel with no pixels. + input.channel[c] = std::move(chout); + return true; + } + + static constexpr const int kColsPerThread = 64; + const auto unsqueeze_slice = [&](const uint32_t task, size_t /* thread */) { + const size_t x0 = task * kColsPerThread; + const size_t x1 = std::min((size_t)(task + 1) * kColsPerThread, chin.w); + const size_t w = x1 - x0; + // We only iterate up to std::min(chin_residual.h, chin.h) which is + // always chin_residual.h. + for (size_t y = 0; y < chin_residual.h; y++) { + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y) + x0; + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y) + x0; + const pixel_type *JXL_RESTRICT p_navg = + chin.Row(y + 1 < chin.h ? y + 1 : y) + x0; + pixel_type *JXL_RESTRICT p_out = chout.Row(y << 1) + x0; + pixel_type *JXL_RESTRICT p_nout = chout.Row((y << 1) + 1) + x0; + const pixel_type *p_pout = y > 0 ? chout.Row((y << 1) - 1) + x0 : p_avg; + size_t x = 0; +#if HWY_TARGET != HWY_SCALAR + for (; x + 7 < w; x += 8) { + FastUnsqueeze(p_residual + x, p_avg + x, p_navg + x, p_pout + x, + p_out + x, p_nout + x); + } +#endif + for (; x < w; x++) { + pixel_type_w avg = p_avg[x]; + pixel_type_w next_avg = p_navg[x]; + pixel_type_w top = p_pout[x]; + pixel_type_w tendency = SmoothTendency(top, avg, next_avg); + pixel_type_w diff_minus_tendency = p_residual[x]; + pixel_type_w diff = diff_minus_tendency + tendency; + pixel_type_w out = avg + (diff / 2); + p_out[x] = out; + // If the chin_residual.h == chin.h, the output has an even number + // of rows so the next line is fine. Otherwise, this loop won't + // write to the last output row which is handled separately. + p_nout[x] = out - diff; + } + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.w, kColsPerThread), + ThreadPool::NoInit, unsqueeze_slice, + "InvVertSqueeze")); + + if (chout.h & 1) { + size_t y = chin.h - 1; + const pixel_type *p_avg = chin.Row(y); + pixel_type *p_out = chout.Row(y << 1); + for (size_t x = 0; x < chin.w; x++) { + p_out[x] = p_avg[x]; + } + } + input.channel[c] = std::move(chout); + return true; +} + +Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool) { + for (int i = parameters.size() - 1; i >= 0; i--) { + JXL_RETURN_IF_ERROR( + CheckMetaSqueezeParams(parameters[i], input.channel.size())); + bool horizontal = parameters[i].horizontal; + bool in_place = parameters[i].in_place; + uint32_t beginc = parameters[i].begin_c; + uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1; + uint32_t offset; + if (in_place) { + offset = endc + 1; + } else { + offset = input.channel.size() + beginc - endc - 1; + } + if (beginc < input.nb_meta_channels) { + // This is checked in MetaSqueeze. + JXL_ASSERT(input.nb_meta_channels > parameters[i].num_c); + input.nb_meta_channels -= parameters[i].num_c; + } + + for (uint32_t c = beginc; c <= endc; c++) { + uint32_t rc = offset + c - beginc; + // MetaApply should imply that `rc` is within range, otherwise there's a + // programming bug. + JXL_ASSERT(rc < input.channel.size()); + if ((input.channel[c].w < input.channel[rc].w) || + (input.channel[c].h < input.channel[rc].h)) { + return JXL_FAILURE("Corrupted squeeze transform"); + } + if (horizontal) { + JXL_RETURN_IF_ERROR(InvHSqueeze(input, c, rc, pool)); + } else { + JXL_RETURN_IF_ERROR(InvVSqueeze(input, c, rc, pool)); + } + } + input.channel.erase(input.channel.begin() + offset, + input.channel.begin() + offset + (endc - beginc + 1)); + } + return true; +} + +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { + +HWY_EXPORT(InvSqueeze); +Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool) { + return HWY_DYNAMIC_DISPATCH(InvSqueeze)(input, parameters, pool); +} + +void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters, + const Image &image) { + int nb_channels = image.channel.size() - image.nb_meta_channels; + + parameters->clear(); + size_t w = image.channel[image.nb_meta_channels].w; + size_t h = image.channel[image.nb_meta_channels].h; + JXL_DEBUG_V( + 7, "Default squeeze parameters for %" PRIuS "x%" PRIuS " image: ", w, h); + + // do horizontal first on wide images; vertical first on tall images + bool wide = (w > h); + + if (nb_channels > 2 && image.channel[image.nb_meta_channels + 1].w == w && + image.channel[image.nb_meta_channels + 1].h == h) { + // assume channels 1 and 2 are chroma, and can be squeezed first for 4:2:0 + // previews + JXL_DEBUG_V(7, "(4:2:0 chroma), %" PRIuS "x%" PRIuS " image", w, h); + SqueezeParams params; + // horizontal chroma squeeze + params.horizontal = true; + params.in_place = false; + params.begin_c = image.nb_meta_channels + 1; + params.num_c = 2; + parameters->push_back(params); + params.horizontal = false; + // vertical chroma squeeze + parameters->push_back(params); + } + SqueezeParams params; + params.begin_c = image.nb_meta_channels; + params.num_c = nb_channels; + params.in_place = true; + + if (!wide) { + if (h > JXL_MAX_FIRST_PREVIEW_SIZE) { + params.horizontal = false; + parameters->push_back(params); + h = (h + 1) / 2; + JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); + } + } + while (w > JXL_MAX_FIRST_PREVIEW_SIZE || h > JXL_MAX_FIRST_PREVIEW_SIZE) { + if (w > JXL_MAX_FIRST_PREVIEW_SIZE) { + params.horizontal = true; + parameters->push_back(params); + w = (w + 1) / 2; + JXL_DEBUG_V(7, "Horizontal (%" PRIuS "x%" PRIuS "), ", w, h); + } + if (h > JXL_MAX_FIRST_PREVIEW_SIZE) { + params.horizontal = false; + parameters->push_back(params); + h = (h + 1) / 2; + JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); + } + } + JXL_DEBUG_V(7, "that's it"); +} + +Status CheckMetaSqueezeParams(const SqueezeParams ¶meter, + int num_channels) { + int c1 = parameter.begin_c; + int c2 = parameter.begin_c + parameter.num_c - 1; + if (c1 < 0 || c1 >= num_channels || c2 < 0 || c2 >= num_channels || c2 < c1) { + return JXL_FAILURE("Invalid channel range"); + } + return true; +} + +Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters) { + if (parameters->empty()) { + DefaultSqueezeParameters(parameters, image); + } + + for (size_t i = 0; i < parameters->size(); i++) { + JXL_RETURN_IF_ERROR( + CheckMetaSqueezeParams((*parameters)[i], image.channel.size())); + bool horizontal = (*parameters)[i].horizontal; + bool in_place = (*parameters)[i].in_place; + uint32_t beginc = (*parameters)[i].begin_c; + uint32_t endc = (*parameters)[i].begin_c + (*parameters)[i].num_c - 1; + + uint32_t offset; + if (beginc < image.nb_meta_channels) { + if (endc >= image.nb_meta_channels) { + return JXL_FAILURE("Invalid squeeze: mix of meta and nonmeta channels"); + } + if (!in_place) { + return JXL_FAILURE( + "Invalid squeeze: meta channels require in-place residuals"); + } + image.nb_meta_channels += (*parameters)[i].num_c; + } + if (in_place) { + offset = endc + 1; + } else { + offset = image.channel.size(); + } + for (uint32_t c = beginc; c <= endc; c++) { + if (image.channel[c].hshift > 30 || image.channel[c].vshift > 30) { + return JXL_FAILURE("Too many squeezes: shift > 30"); + } + size_t w = image.channel[c].w; + size_t h = image.channel[c].h; + if (w == 0 || h == 0) return JXL_FAILURE("Squeezing empty channel"); + if (horizontal) { + image.channel[c].w = (w + 1) / 2; + if (image.channel[c].hshift >= 0) image.channel[c].hshift++; + w = w - (w + 1) / 2; + } else { + image.channel[c].h = (h + 1) / 2; + if (image.channel[c].vshift >= 0) image.channel[c].vshift++; + h = h - (h + 1) / 2; + } + image.channel[c].shrink(); + Channel dummy(w, h); + dummy.hshift = image.channel[c].hshift; + dummy.vshift = image.channel[c].vshift; + + image.channel.insert(image.channel.begin() + offset + (c - beginc), + std::move(dummy)); + JXL_DEBUG_V(8, "MetaSqueeze applied, current image: %s", + image.DebugString().c_str()); + } + } + return true; +} + +} // namespace jxl + +#endif diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h new file mode 100644 index 0000000000..fb18710a6f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h @@ -0,0 +1,90 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_ + +// Haar-like transform: halves the resolution in one direction +// A B -> (A+B)>>1 in one channel (average) -> same range as +// original channel +// A-B - tendency in a new channel ('residual' needed to make +// the transform reversible) +// -> theoretically range could be 2.5 +// times larger (2 times without the +// 'tendency'), but there should be lots +// of zeroes +// Repeated application (alternating horizontal and vertical squeezes) results +// in downscaling +// +// The default coefficient ordering is low-frequency to high-frequency, as in +// M. Antonini, M. Barlaud, P. Mathieu and I. Daubechies, "Image coding using +// wavelet transform", IEEE Transactions on Image Processing, vol. 1, no. 2, pp. +// 205-220, April 1992, doi: 10.1109/83.136597. + +#include <stdlib.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +#define JXL_MAX_FIRST_PREVIEW_SIZE 8 + +namespace jxl { + +/* + int avg=(A+B)>>1; + int diff=(A-B); + int rA=(diff+(avg<<1)+(diff&1))>>1; + int rB=rA-diff; + +*/ +// |A B|C D|E F| +// p a n p=avg(A,B), a=avg(C,D), n=avg(E,F) +// +// Goal: estimate C-D (avoiding ringing artifacts) +// (ensuring that in smooth areas, a zero residual corresponds to a smooth +// gradient) + +// best estimate for C: (B + 2*a)/3 +// best estimate for D: (n + 3*a)/4 +// best estimate for C-D: 4*B - 3*n - a /12 + +// avoid ringing by 1) only doing this if B <= a <= n or B >= a >= n +// (otherwise, this is not a smooth area and we cannot really estimate C-D) +// 2) making sure that B <= C <= D <= n or B >= C >= D >= n + +inline pixel_type_w SmoothTendency(pixel_type_w B, pixel_type_w a, + pixel_type_w n) { + pixel_type_w diff = 0; + if (B >= a && a >= n) { + diff = (4 * B - 3 * n - a + 6) / 12; + // 2C = a<<1 + diff - diff&1 <= 2B so diff - diff&1 <= 2B - 2a + // 2D = a<<1 - diff - diff&1 >= 2n so diff + diff&1 <= 2a - 2n + if (diff - (diff & 1) > 2 * (B - a)) diff = 2 * (B - a) + 1; + if (diff + (diff & 1) > 2 * (a - n)) diff = 2 * (a - n); + } else if (B <= a && a <= n) { + diff = (4 * B - 3 * n - a - 6) / 12; + // 2C = a<<1 + diff + diff&1 >= 2B so diff + diff&1 >= 2B - 2a + // 2D = a<<1 - diff + diff&1 <= 2n so diff - diff&1 >= 2a - 2n + if (diff + (diff & 1) < 2 * (B - a)) diff = 2 * (B - a) - 1; + if (diff - (diff & 1) < 2 * (a - n)) diff = 2 * (a - n); + } + return diff; +} + +void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters, + const Image &image); + +Status CheckMetaSqueezeParams(const SqueezeParams ¶meter, int num_channels); + +Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters); + +Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc new file mode 100644 index 0000000000..d9f2b435bf --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc @@ -0,0 +1,98 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/transform.h" + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/palette.h" +#include "lib/jxl/modular/transform/rct.h" +#include "lib/jxl/modular/transform/squeeze.h" + +namespace jxl { + +SqueezeParams::SqueezeParams() { Bundle::Init(this); } +Transform::Transform(TransformId id) { + Bundle::Init(this); + this->id = id; +} + +Status Transform::Inverse(Image &input, const weighted::Header &wp_header, + ThreadPool *pool) { + JXL_DEBUG_V(6, "Input channels (%" PRIuS ", %" PRIuS " meta): ", + input.channel.size(), input.nb_meta_channels); + switch (id) { + case TransformId::kRCT: + return InvRCT(input, begin_c, rct_type, pool); + case TransformId::kSqueeze: + return InvSqueeze(input, squeezes, pool); + case TransformId::kPalette: + return InvPalette(input, begin_c, nb_colors, nb_deltas, predictor, + wp_header, pool); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast<unsigned int>(id)); + } +} + +Status Transform::MetaApply(Image &input) { + JXL_DEBUG_V(6, "MetaApply input: %s", input.DebugString().c_str()); + switch (id) { + case TransformId::kRCT: + JXL_DEBUG_V(2, "Transform: kRCT, rct_type=%" PRIu32, rct_type); + return CheckEqualChannels(input, begin_c, begin_c + 2); + case TransformId::kSqueeze: + JXL_DEBUG_V(2, "Transform: kSqueeze:"); +#if JXL_DEBUG_V_LEVEL >= 2 + { + auto squeezes_copy = squeezes; + if (squeezes_copy.empty()) { + DefaultSqueezeParameters(&squeezes_copy, input); + } + for (const auto ¶ms : squeezes_copy) { + JXL_DEBUG_V( + 2, + " squeeze params: horizontal=%d, in_place=%d, begin_c=%" PRIu32 + ", num_c=%" PRIu32, + params.horizontal, params.in_place, params.begin_c, params.num_c); + } + } +#endif + return MetaSqueeze(input, &squeezes); + case TransformId::kPalette: + JXL_DEBUG_V(2, + "Transform: kPalette, begin_c=%" PRIu32 ", num_c=%" PRIu32 + ", nb_colors=%" PRIu32 ", nb_deltas=%" PRIu32, + begin_c, num_c, nb_colors, nb_deltas); + return MetaPalette(input, begin_c, begin_c + num_c - 1, nb_colors, + nb_deltas, lossy_palette); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast<unsigned int>(id)); + } +} + +Status CheckEqualChannels(const Image &image, uint32_t c1, uint32_t c2) { + if (c1 > image.channel.size() || c2 >= image.channel.size() || c2 < c1) { + return JXL_FAILURE("Invalid channel range: %u..%u (there are only %" PRIuS + " channels)", + c1, c2, image.channel.size()); + } + if (c1 < image.nb_meta_channels && c2 >= image.nb_meta_channels) { + return JXL_FAILURE("Invalid: transforming mix of meta and nonmeta"); + } + const auto &ch1 = image.channel[c1]; + for (size_t c = c1 + 1; c <= c2; c++) { + const auto &ch2 = image.channel[c]; + if (ch1.w != ch2.w || ch1.h != ch2.h || ch1.hshift != ch2.hshift || + ch1.vshift != ch2.vshift) { + return false; + } + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h new file mode 100644 index 0000000000..d5d3259f7a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h @@ -0,0 +1,148 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ +#define LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ + +#include <cstdint> +#include <string> +#include <vector> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +enum class TransformId : uint32_t { + // G, R-G, B-G and variants (including YCoCg). + kRCT = 0, + + // Color palette. Parameters are: [begin_c] [end_c] [nb_colors] + kPalette = 1, + + // Squeezing (Haar-style) + kSqueeze = 2, + + // Invalid for now. + kInvalid = 3, +}; + +struct SqueezeParams : public Fields { + JXL_FIELDS_NAME(SqueezeParams) + bool horizontal; + bool in_place; + uint32_t begin_c; + uint32_t num_c; + SqueezeParams(); + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &horizontal)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &in_place)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(3), BitsOffset(6, 8), + BitsOffset(10, 72), + BitsOffset(13, 1096), 0, &begin_c)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), BitsOffset(4, 4), 2, &num_c)); + return true; + } +}; + +class Transform : public Fields { + public: + TransformId id; + // for Palette and RCT. + uint32_t begin_c; + // for RCT. 42 possible values starting from 0. + uint32_t rct_type; + // Only for Palette and NearLossless. + uint32_t num_c; + // Only for Palette. + uint32_t nb_colors; + uint32_t nb_deltas; + // for Squeeze. Default squeeze if empty. + std::vector<SqueezeParams> squeezes; + // for NearLossless, not serialized. + int max_delta_error; + // Serialized for Palette. + Predictor predictor; + // for Palette, not serialized. + bool ordered_palette = true; + bool lossy_palette = false; + + explicit Transform(TransformId id); + // default constructor for bundles. + Transform() : Transform(TransformId::kInvalid) {} + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val((uint32_t)TransformId::kRCT), Val((uint32_t)TransformId::kPalette), + Val((uint32_t)TransformId::kSqueeze), + Val((uint32_t)TransformId::kInvalid), (uint32_t)TransformId::kRCT, + reinterpret_cast<uint32_t *>(&id))); + if (id == TransformId::kInvalid) { + return JXL_FAILURE("Invalid transform ID"); + } + if (visitor->Conditional(id == TransformId::kRCT || + id == TransformId::kPalette)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(3), BitsOffset(6, 8), BitsOffset(10, 72), + BitsOffset(13, 1096), 0, &begin_c)); + } + if (visitor->Conditional(id == TransformId::kRCT)) { + // 0-41, default YCoCg. + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(6), Bits(2), BitsOffset(4, 2), + BitsOffset(6, 10), 6, &rct_type)); + if (rct_type >= 42) { + return JXL_FAILURE("Invalid transform RCT type"); + } + } + if (visitor->Conditional(id == TransformId::kPalette)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(3), Val(4), BitsOffset(13, 1), 3, &num_c)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + BitsOffset(8, 0), BitsOffset(10, 256), BitsOffset(12, 1280), + BitsOffset(16, 5376), 256, &nb_colors)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), BitsOffset(8, 1), BitsOffset(10, 257), + BitsOffset(16, 1281), 0, &nb_deltas)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bits(4, (uint32_t)Predictor::Zero, + reinterpret_cast<uint32_t *>(&predictor))); + if (predictor >= Predictor::Best) { + return JXL_FAILURE("Invalid predictor"); + } + } + + if (visitor->Conditional(id == TransformId::kSqueeze)) { + uint32_t num_squeezes = static_cast<uint32_t>(squeezes.size()); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), BitsOffset(4, 1), BitsOffset(6, 9), + BitsOffset(8, 41), 0, &num_squeezes)); + if (visitor->IsReading()) squeezes.resize(num_squeezes); + for (size_t i = 0; i < num_squeezes; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&squeezes[i])); + } + } + return true; + } + + JXL_FIELDS_NAME(Transform) + + Status Inverse(Image &input, const weighted::Header &wp_header, + ThreadPool *pool = nullptr); + Status MetaApply(Image &input); +}; + +Status CheckEqualChannels(const Image &image, uint32_t c1, uint32_t c2); + +static inline pixel_type PixelAdd(pixel_type a, pixel_type b) { + return static_cast<pixel_type>(static_cast<uint32_t>(a) + + static_cast<uint32_t>(b)); +} + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ |