summaryrefslogtreecommitdiffstats
path: root/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc')
-rw-r--r--third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc347
1 files changed, 347 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc
new file mode 100644
index 0000000000..85e5de3c8d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc
@@ -0,0 +1,347 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/dec_patch_dictionary.h"
+
+#include <stdint.h>
+#include <stdlib.h>
+#include <sys/types.h>
+
+#include <algorithm>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/override.h"
+#include "lib/jxl/base/printf_macros.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/blending.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_frame.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/patch_dictionary_internal.h"
+
+namespace jxl {
+
+Status PatchDictionary::Decode(BitReader* br, size_t xsize, size_t ysize,
+ bool* uses_extra_channels) {
+ positions_.clear();
+ std::vector<uint8_t> context_map;
+ ANSCode code;
+ JXL_RETURN_IF_ERROR(
+ DecodeHistograms(br, kNumPatchDictionaryContexts, &code, &context_map));
+ ANSSymbolReader decoder(&code, br);
+
+ auto read_num = [&](size_t context) {
+ size_t r = decoder.ReadHybridUint(context, br, context_map);
+ return r;
+ };
+
+ size_t num_ref_patch = read_num(kNumRefPatchContext);
+ // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8
+ // bytes per size_t)
+ const size_t num_pixels = xsize * ysize;
+ const size_t max_ref_patches = 1024 + num_pixels / 4;
+ const size_t max_patches = max_ref_patches * 4;
+ const size_t max_blending_infos = max_patches * 4;
+ if (num_ref_patch > max_ref_patches) {
+ return JXL_FAILURE("Too many patches in dictionary");
+ }
+ size_t num_ec = shared_->metadata->m.num_extra_channels;
+
+ size_t total_patches = 0;
+ size_t next_size = 1;
+
+ for (size_t id = 0; id < num_ref_patch; id++) {
+ PatchReferencePosition ref_pos;
+ ref_pos.ref = read_num(kReferenceFrameContext);
+ if (ref_pos.ref >= kMaxNumReferenceFrames ||
+ shared_->reference_frames[ref_pos.ref].frame.xsize() == 0) {
+ return JXL_FAILURE("Invalid reference frame ID");
+ }
+ if (!shared_->reference_frames[ref_pos.ref].ib_is_in_xyb) {
+ return JXL_FAILURE(
+ "Patches cannot use frames saved post color transforms");
+ }
+ const ImageBundle& ib = shared_->reference_frames[ref_pos.ref].frame;
+ ref_pos.x0 = read_num(kPatchReferencePositionContext);
+ ref_pos.y0 = read_num(kPatchReferencePositionContext);
+ ref_pos.xsize = read_num(kPatchSizeContext) + 1;
+ ref_pos.ysize = read_num(kPatchSizeContext) + 1;
+ if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) {
+ return JXL_FAILURE("Invalid position specified in reference frame");
+ }
+ if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) {
+ return JXL_FAILURE("Invalid position specified in reference frame");
+ }
+ size_t id_count = read_num(kPatchCountContext) + 1;
+ total_patches += id_count;
+ if (total_patches > max_patches) {
+ return JXL_FAILURE("Too many patches in dictionary");
+ }
+ if (next_size < total_patches) {
+ next_size *= 2;
+ next_size = std::min<size_t>(next_size, max_patches);
+ }
+ if (next_size * (num_ec + 1) > max_blending_infos) {
+ return JXL_FAILURE("Too many patches in dictionary");
+ }
+ positions_.reserve(next_size);
+ blendings_.reserve(next_size * (num_ec + 1));
+ for (size_t i = 0; i < id_count; i++) {
+ PatchPosition pos;
+ pos.ref_pos_idx = ref_positions_.size();
+ if (i == 0) {
+ pos.x = read_num(kPatchPositionContext);
+ pos.y = read_num(kPatchPositionContext);
+ } else {
+ pos.x =
+ positions_.back().x + UnpackSigned(read_num(kPatchOffsetContext));
+ pos.y =
+ positions_.back().y + UnpackSigned(read_num(kPatchOffsetContext));
+ }
+ if (pos.x + ref_pos.xsize > xsize) {
+ return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS
+ " > %" PRIuS,
+ pos.x, ref_pos.xsize, xsize);
+ }
+ if (pos.y + ref_pos.ysize > ysize) {
+ return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS
+ " > %" PRIuS,
+ pos.y, ref_pos.ysize, ysize);
+ }
+ for (size_t j = 0; j < num_ec + 1; j++) {
+ uint32_t blend_mode = read_num(kPatchBlendModeContext);
+ if (blend_mode >= uint32_t(PatchBlendMode::kNumBlendModes)) {
+ return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode);
+ }
+ PatchBlending info;
+ info.mode = static_cast<PatchBlendMode>(blend_mode);
+ if (UsesAlpha(info.mode)) {
+ *uses_extra_channels = true;
+ }
+ if (info.mode != PatchBlendMode::kNone && j > 0) {
+ *uses_extra_channels = true;
+ }
+ if (UsesAlpha(info.mode) &&
+ shared_->metadata->m.extra_channel_info.size() > 1) {
+ info.alpha_channel = read_num(kPatchAlphaChannelContext);
+ if (info.alpha_channel >=
+ shared_->metadata->m.extra_channel_info.size()) {
+ return JXL_FAILURE(
+ "Invalid alpha channel for blending: %u out of %u\n",
+ info.alpha_channel,
+ (uint32_t)shared_->metadata->m.extra_channel_info.size());
+ }
+ } else {
+ info.alpha_channel = 0;
+ }
+ if (UsesClamp(info.mode)) {
+ info.clamp = read_num(kPatchClampContext);
+ } else {
+ info.clamp = false;
+ }
+ blendings_.push_back(info);
+ }
+ positions_.push_back(std::move(pos));
+ }
+ ref_positions_.emplace_back(std::move(ref_pos));
+ }
+ positions_.shrink_to_fit();
+
+ if (!decoder.CheckANSFinalState()) {
+ return JXL_FAILURE("ANS checksum failure.");
+ }
+
+ ComputePatchTree();
+ return true;
+}
+
+int PatchDictionary::GetReferences() const {
+ int result = 0;
+ for (size_t i = 0; i < ref_positions_.size(); ++i) {
+ result |= (1 << static_cast<int>(ref_positions_[i].ref));
+ }
+ return result;
+}
+
+namespace {
+struct PatchInterval {
+ size_t idx;
+ size_t y0, y1;
+};
+} // namespace
+
+void PatchDictionary::ComputePatchTree() {
+ patch_tree_.clear();
+ num_patches_.clear();
+ sorted_patches_y0_.clear();
+ sorted_patches_y1_.clear();
+ if (positions_.empty()) {
+ return;
+ }
+ // Create a y-interval for each patch.
+ std::vector<PatchInterval> intervals(positions_.size());
+ for (size_t i = 0; i < positions_.size(); ++i) {
+ const auto& pos = positions_[i];
+ intervals[i].idx = i;
+ intervals[i].y0 = pos.y;
+ intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize;
+ }
+ auto sort_by_y0 = [&intervals](size_t start, size_t end) {
+ std::sort(intervals.data() + start, intervals.data() + end,
+ [](const PatchInterval& i0, const PatchInterval& i1) {
+ return i0.y0 < i1.y0;
+ });
+ };
+ auto sort_by_y1 = [&intervals](size_t start, size_t end) {
+ std::sort(intervals.data() + start, intervals.data() + end,
+ [](const PatchInterval& i0, const PatchInterval& i1) {
+ return i0.y1 < i1.y1;
+ });
+ };
+ // Count the number of patches for each row.
+ sort_by_y1(0, intervals.size());
+ num_patches_.resize(intervals.back().y1);
+ for (auto iv : intervals) {
+ for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++;
+ }
+ PatchTreeNode root;
+ root.start = 0;
+ root.num = intervals.size();
+ patch_tree_.push_back(root);
+ size_t next = 0;
+ while (next < patch_tree_.size()) {
+ auto& node = patch_tree_[next];
+ size_t start = node.start;
+ size_t end = node.start + node.num;
+ // Choose the y_center for this node to be the median of interval starts.
+ sort_by_y0(start, end);
+ size_t middle_idx = start + node.num / 2;
+ node.y_center = intervals[middle_idx].y0;
+ // Divide the intervals in [start, end) into three groups:
+ // * those completely to the right of y_center: [right_start, end)
+ // * those overlapping y_center: [left_end, right_start)
+ // * those completely to the left of y_center: [start, left_end)
+ size_t right_start = middle_idx;
+ while (right_start < end && intervals[right_start].y0 == node.y_center) {
+ ++right_start;
+ }
+ sort_by_y1(start, right_start);
+ size_t left_end = right_start;
+ while (left_end > start && intervals[left_end - 1].y1 > node.y_center) {
+ --left_end;
+ }
+ // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node.
+ node.num = right_start - left_end;
+ node.start = sorted_patches_y0_.size();
+ for (ssize_t i = static_cast<ssize_t>(right_start) - 1;
+ i >= static_cast<ssize_t>(left_end); --i) {
+ sorted_patches_y1_.push_back({intervals[i].y1, intervals[i].idx});
+ }
+ sort_by_y0(left_end, right_start);
+ for (size_t i = left_end; i < right_start; ++i) {
+ sorted_patches_y0_.push_back({intervals[i].y0, intervals[i].idx});
+ }
+ // Create the left and right nodes (if not empty).
+ node.left_child = node.right_child = -1;
+ if (left_end > start) {
+ PatchTreeNode left;
+ left.start = start;
+ left.num = left_end - left.start;
+ patch_tree_[next].left_child = patch_tree_.size();
+ patch_tree_.push_back(left);
+ }
+ if (right_start < end) {
+ PatchTreeNode right;
+ right.start = right_start;
+ right.num = end - right.start;
+ patch_tree_[next].right_child = patch_tree_.size();
+ patch_tree_.push_back(right);
+ }
+ ++next;
+ }
+}
+
+std::vector<size_t> PatchDictionary::GetPatchesForRow(size_t y) const {
+ std::vector<size_t> result;
+ if (y < num_patches_.size() && num_patches_[y] > 0) {
+ result.reserve(num_patches_[y]);
+ for (ssize_t tree_idx = 0; tree_idx != -1;) {
+ JXL_DASSERT(tree_idx < (ssize_t)patch_tree_.size());
+ const auto& node = patch_tree_[tree_idx];
+ if (y <= node.y_center) {
+ for (size_t i = 0; i < node.num; ++i) {
+ const auto& p = sorted_patches_y0_[node.start + i];
+ if (y < p.first) break;
+ result.push_back(p.second);
+ }
+ tree_idx = y < node.y_center ? node.left_child : -1;
+ } else {
+ for (size_t i = 0; i < node.num; ++i) {
+ const auto& p = sorted_patches_y1_[node.start + i];
+ if (y >= p.first) break;
+ result.push_back(p.second);
+ }
+ tree_idx = node.right_child;
+ }
+ }
+ // Ensure that he relative order of patches that affect the same pixels is
+ // preserved. This is important for patches that have a blend mode
+ // different from kAdd.
+ std::sort(result.begin(), result.end());
+ }
+ return result;
+}
+
+// Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed
+// to be located at position (x0, y) in the frame.
+void PatchDictionary::AddOneRow(float* const* inout, size_t y, size_t x0,
+ size_t xsize) const {
+ size_t num_ec = shared_->metadata->m.num_extra_channels;
+ std::vector<const float*> fg_ptrs(3 + num_ec);
+ for (size_t pos_idx : GetPatchesForRow(y)) {
+ const size_t blending_idx = pos_idx * (num_ec + 1);
+ const PatchPosition& pos = positions_[pos_idx];
+ const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx];
+ size_t by = pos.y;
+ size_t bx = pos.x;
+ size_t patch_xsize = ref_pos.xsize;
+ JXL_DASSERT(y >= by);
+ JXL_DASSERT(y < by + ref_pos.ysize);
+ size_t iy = y - by;
+ size_t ref = ref_pos.ref;
+ if (bx >= x0 + xsize) continue;
+ if (bx + patch_xsize < x0) continue;
+ size_t patch_x0 = std::max(bx, x0);
+ size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize);
+ for (size_t c = 0; c < 3; c++) {
+ fg_ptrs[c] = shared_->reference_frames[ref].frame.color().ConstPlaneRow(
+ c, ref_pos.y0 + iy) +
+ ref_pos.x0 + x0 - bx;
+ }
+ for (size_t i = 0; i < num_ec; i++) {
+ fg_ptrs[3 + i] =
+ shared_->reference_frames[ref].frame.extra_channels()[i].ConstRow(
+ ref_pos.y0 + iy) +
+ ref_pos.x0 + x0 - bx;
+ }
+ PerformBlending(inout, fg_ptrs.data(), inout, patch_x0 - x0,
+ patch_x1 - patch_x0, blendings_[blending_idx],
+ blendings_.data() + blending_idx + 1,
+ shared_->metadata->m.extra_channel_info);
+ }
+}
+} // namespace jxl