Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/dec_group.cc')
-rw-r--r--  third_party/jpeg-xl/lib/jxl/dec_group.cc  801
1 file changed, 801 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/dec_group.cc b/third_party/jpeg-xl/lib/jxl/dec_group.cc
new file mode 100644
index 0000000000..be8df9b062
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_group.cc
@@ -0,0 +1,801 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/dec_group.h"
+
+#include <stdint.h>
+#include <string.h>
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+
+#include "lib/jxl/frame_header.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_context.h"
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/printf_macros.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/coeff_order.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/dec_transforms-inl.h"
+#include "lib/jxl/dec_xyb.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/epf.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/quant_weights.h"
+#include "lib/jxl/quantizer-inl.h"
+#include "lib/jxl/quantizer.h"
+
+#ifndef LIB_JXL_DEC_GROUP_CC
+#define LIB_JXL_DEC_GROUP_CC
+namespace jxl {
+
+struct AuxOut;
+
+// Interface for reading groups for DecodeGroupImpl.
+class GetBlock {
+ public:
+  virtual void StartRow(size_t by) = 0;
+  virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs,
+                           size_t size, size_t log2_covered_blocks,
+                           ACPtr block[3], ACType ac_type) = 0;
+  virtual ~GetBlock() {}
+};
+
+// Controls whether DecodeGroupImpl renders to pixels or not.
+enum DrawMode {
+  // Render to pixels.
+  kDraw = 0,
+  // Don't render to pixels.
+  kDontDraw = 1,
+};
+
+}  // namespace jxl
+#endif  // LIB_JXL_DEC_GROUP_CC
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Rebind;
+using hwy::HWY_NAMESPACE::ShiftRight;
+
+using D = HWY_FULL(float);
+using DU = HWY_FULL(uint32_t);
+using DI = HWY_FULL(int32_t);
+using DI16 = Rebind<int16_t, DI>;
+constexpr D d;
+constexpr DI di;
+constexpr DI16 di16;
+
+// TODO(veluca): consider SIMDfying.
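+// Swaps block[y][x] with block[x][y] across the main diagonal. Used by the
+// JPEG-reconstruction path below, since JPEG XL stores DCT coefficients
+// transposed relative to JPEG order.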
+void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
+  for (size_t x = 0; x < 8; x++) {
+    for (size_t y = x + 1; y < 8; y++) {
+      std::swap(block[y * 8 + x], block[x * 8 + y]);
+    }
+  }
+}
+
+template <ACType ac_type>
+void DequantLane(Vec<D> scaled_dequant_x, Vec<D> scaled_dequant_y,
+                 Vec<D> scaled_dequant_b,
+                 const float* JXL_RESTRICT dequant_matrices, size_t size,
+                 size_t k, Vec<D> x_cc_mul, Vec<D> b_cc_mul,
+                 const float* JXL_RESTRICT biases, ACPtr qblock[3],
+                 float* JXL_RESTRICT block) {
+  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
+  const auto y_mul =
+      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
+  const auto b_mul =
+      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
+
+  Vec<DI> quantized_x_int;
+  Vec<DI> quantized_y_int;
+  Vec<DI> quantized_b_int;
+  if (ac_type == ACType::k16) {
+    Rebind<int16_t, DI> di16;
+    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
+    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
+    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
+  } else {
+    quantized_x_int = Load(di, qblock[0].ptr32 + k);
+    quantized_y_int = Load(di, qblock[1].ptr32 + k);
+    quantized_b_int = Load(di, qblock[2].ptr32 + k);
+  }
+
+  const auto dequant_x_cc =
+      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
+  const auto dequant_y =
+      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
+  const auto dequant_b_cc =
+      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
+
+  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
+  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
+  Store(dequant_x, d, block + k);
+  Store(dequant_y, d, block + size + k);
+  Store(dequant_b, d, block + 2 * size + k);
+}
+
+template <ACType ac_type>
+void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant,
+                  float x_dm_multiplier, float b_dm_multiplier, Vec<D> x_cc_mul,
+                  Vec<D> b_cc_mul, size_t kind, size_t size,
+                  const Quantizer& quantizer, size_t covered_blocks,
+                  const size_t* sbx,
+                  const float* JXL_RESTRICT* JXL_RESTRICT dc_row,
+                  size_t dc_stride, const float* JXL_RESTRICT biases,
+                  ACPtr qblock[3], float* JXL_RESTRICT block) {
+  PROFILER_FUNC;
+
+  const auto scaled_dequant_s = inv_global_scale / quant;
+
+  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
+  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
+  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
+
+  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
+
+  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
+    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
+                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
+                         qblock, block);
+  }
+  for (size_t c = 0; c < 3; c++) {
+    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
+                            block + c * size);
+  }
+}
+
+Status DecodeGroupImpl(GetBlock* JXL_RESTRICT get_block,
+                       GroupDecCache* JXL_RESTRICT group_dec_cache,
+                       PassesDecoderState* JXL_RESTRICT dec_state,
+                       size_t thread, size_t group_idx,
+                       RenderPipelineInput& render_pipeline_input,
+                       ImageBundle* decoded, DrawMode draw) {
+  // TODO(veluca): investigate cache usage in this function.
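+  // Overall flow: for each block row of the group, fetch the quantized
+  // coefficients through `get_block`, dequantize them (applying
+  // chroma-from-luma and the LLF prediction from DC), and inverse-transform
+  // the result into the render pipeline buffers - unless we are
+  // reconstructing a JPEG, in which case quantized coefficients are emitted
+  // directly.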
+  PROFILER_FUNC;
+  const Rect block_rect = dec_state->shared->BlockGroupRect(group_idx);
+  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
+
+  const size_t xsize_blocks = block_rect.xsize();
+  const size_t ysize_blocks = block_rect.ysize();
+
+  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
+
+  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
+
+  const YCbCrChromaSubsampling& cs =
+      dec_state->shared->frame_header.chroma_subsampling;
+
+  size_t idct_stride[3];
+  for (size_t c = 0; c < 3; c++) {
+    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
+  }
+
+  HWY_ALIGN int32_t scaled_qtable[64 * 3];
+
+  ACType ac_type = dec_state->coefficients->Type();
+  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
+                                              : DequantBlock<ACType::k32>;
+  // Whether or not coefficients should be stored for future usage, and/or read
+  // from past usage.
+  bool accumulate = !dec_state->coefficients->IsEmpty();
+  // Offset of the current block in the group.
+  size_t offset = 0;
+
+  std::array<int, 3> jpeg_c_map;
+  bool jpeg_is_gray = false;
+  std::array<int, 3> dcoff = {};
+
+  // TODO(veluca): all of this should be done only once per image.
+  if (decoded->IsJPEG()) {
+    if (!dec_state->shared->cmap.IsJPEGCompatible()) {
+      return JXL_FAILURE("The CfL map is not JPEG-compatible");
+    }
+    jpeg_is_gray = (decoded->jpeg_data->components.size() == 1);
+    jpeg_c_map = JpegOrder(dec_state->shared->frame_header.color_transform,
+                           jpeg_is_gray);
+    const std::vector<QuantEncoding>& qe =
+        dec_state->shared->matrices.encodings();
+    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
+        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
+      return JXL_FAILURE(
+          "Quantization table is not a JPEG quantization table.");
+    }
+    for (size_t c = 0; c < 3; c++) {
+      if (dec_state->shared->frame_header.color_transform ==
+          ColorTransform::kNone) {
+        dcoff[c] = 1024 / (*qe[0].qraw.qtable)[64 * c];
+      }
+      for (size_t i = 0; i < 64; i++) {
+        // Transpose the matrix, as it will be used on the transposed block.
+        int n = qe[0].qraw.qtable->at(64 + i);
+        int d = qe[0].qraw.qtable->at(64 * c + i);
+        if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) {
+          return JXL_FAILURE("Invalid JPEG quantization table");
+        }
+        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
+            (1 << kCFLFixedPointPrecision) * n / d;
+      }
+    }
+  }
+
+  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
+  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
+  Rect r[3];
+  for (size_t i = 0; i < 3; i++) {
+    r[i] =
+        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
+             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
+    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
+                        dec_state->shared->dc->Plane(i).ysize()})) {
+      return JXL_FAILURE("Frame dimensions are too big for the image.");
+    }
+  }
+
+  for (size_t by = 0; by < ysize_blocks; ++by) {
+    get_block->StartRow(by);
+    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
+
+    const int32_t* JXL_RESTRICT row_quant =
+        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
+
+    const float* JXL_RESTRICT dc_rows[3] = {
+        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
+        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
+        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
+    };
+
+    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
+    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
+
+    const int8_t* JXL_RESTRICT row_cmap[3] = {
+        dec_state->shared->cmap.ytox_map.ConstRow(ty),
+        nullptr,
+        dec_state->shared->cmap.ytob_map.ConstRow(ty),
+    };
+
+    float* JXL_RESTRICT idct_row[3];
+    int16_t* JXL_RESTRICT jpeg_row[3];
+    for (size_t c = 0; c < 3; c++) {
+      idct_row[c] = render_pipeline_input.GetBuffer(c).second.Row(
+          render_pipeline_input.GetBuffer(c).first, sby[c] * kBlockDim);
+      if (decoded->IsJPEG()) {
+        auto& component = decoded->jpeg_data->components[jpeg_c_map[c]];
+        jpeg_row[c] =
+            component.coeffs.data() +
+            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
+                kDCTBlockSize;
+      }
+    }
+
+    size_t bx = 0;
+    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
+         tx++) {
+      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
+      auto x_cc_mul =
+          Set(d, dec_state->shared->cmap.YtoXRatio(row_cmap[0][abs_tx]));
+      auto b_cc_mul =
+          Set(d, dec_state->shared->cmap.YtoBRatio(row_cmap[2][abs_tx]));
+      // Increment bx by llf_x because those iterations would otherwise
+      // immediately continue (!IsFirstBlock). Reduces mispredictions.
+      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
+        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
+        AcStrategy acs = acs_row[bx];
+        const size_t llf_x = acs.covered_blocks_x();
+
+        // Can only happen in the second or lower rows of a varblock.
+        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
+          bx += llf_x;
+          continue;
+        }
+        PROFILER_ZONE("DecodeGroupImpl inner");
+        const size_t log2_covered_blocks = acs.log2_covered_blocks();
+
+        const size_t covered_blocks = 1 << log2_covered_blocks;
+        const size_t size = covered_blocks * kDCTBlockSize;
+
+        ACPtr qblock[3];
+        if (accumulate) {
+          for (size_t c = 0; c < 3; c++) {
+            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
+          }
+        } else {
+          // No point in reading from bitstream without accumulating and not
+          // drawing.
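+          // The scratch buffers below are zero-filled before LoadBlock
+          // because the per-pass decoder accumulates (`+=`) coefficients
+          // into them.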
+          JXL_ASSERT(draw == kDraw);
+          if (ac_type == ACType::k16) {
+            memset(group_dec_cache->dec_group_qblock16, 0,
+                   size * 3 * sizeof(int16_t));
+            for (size_t c = 0; c < 3; c++) {
+              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
+            }
+          } else {
+            memset(group_dec_cache->dec_group_qblock, 0,
+                   size * 3 * sizeof(int32_t));
+            for (size_t c = 0; c < 3; c++) {
+              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
+            }
+          }
+        }
+        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
+            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
+        offset += size;
+        if (draw == kDontDraw) {
+          bx += llf_x;
+          continue;
+        }
+
+        if (JXL_UNLIKELY(decoded->IsJPEG())) {
+          if (acs.Strategy() != AcStrategy::Type::DCT) {
+            return JXL_FAILURE(
+                "Can only decode to JPEG if only DCT-8 is used.");
+          }
+
+          HWY_ALIGN int32_t transposed_dct_y[64];
+          for (size_t c : {1, 0, 2}) {
+            // Propagate only Y for grayscale.
+            if (jpeg_is_gray && c != 1) {
+              continue;
+            }
+            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
+              continue;
+            }
+            int16_t* JXL_RESTRICT jpeg_pos =
+                jpeg_row[c] + sbx[c] * kDCTBlockSize;
+            // JPEG XL is transposed, JPEG is not.
+            auto transposed_dct = qblock[c].ptr32;
+            Transpose8x8InPlace(transposed_dct);
+            // No CfL - no need to store the y block converted to integers.
+            if (!cs.Is444() ||
+                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
+              for (size_t i = 0; i < 64; i += Lanes(d)) {
+                const auto ini = Load(di, transposed_dct + i);
+                const auto ini16 = DemoteTo(di16, ini);
+                StoreU(ini16, di16, jpeg_pos + i);
+              }
+            } else if (c == 1) {
+              // Y channel: save for restoring X/B, but nothing else to do.
+              for (size_t i = 0; i < 64; i += Lanes(d)) {
+                const auto ini = Load(di, transposed_dct + i);
+                Store(ini, di, transposed_dct_y + i);
+                const auto ini16 = DemoteTo(di16, ini);
+                StoreU(ini16, di16, jpeg_pos + i);
+              }
+            } else {
+              // transposed_dct_y contains the y channel block, transposed.
+              const auto scale = Set(
+                  di, dec_state->shared->cmap.RatioJPEG(row_cmap[c][abs_tx]));
+              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
+              for (int i = 0; i < 64; i += Lanes(d)) {
+                auto in = Load(di, transposed_dct + i);
+                auto in_y = Load(di, transposed_dct_y + i);
+                auto qt = Load(di, scaled_qtable + c * size + i);
+                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
+                    Add(Mul(qt, scale), round));
+                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
+                    Add(Mul(in_y, coeff_scale), round));
+                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
+              }
+            }
+            jpeg_pos[0] =
+                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
+          }
+        } else {
+          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
+          // Dequantize and add predictions.
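+          // For each coefficient k of channel c, this computes (in SIMD
+          // lanes):
+          //   block_c[k] = AdjustQuantBias(q_c[k]) * dequant_matrix_c[k] *
+          //                (inv_global_scale / quant) * dm_multiplier_c
+          // then adds the chroma-from-luma prediction (x += ytox * y,
+          // b += ytob * y) and fills the lowest frequencies from DC.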
+          dequant_block(
+              acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
+              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(),
+              size, dec_state->shared->quantizer,
+              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
+              dc_stride,
+              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
+              block);
+
+          for (size_t c : {1, 0, 2}) {
+            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
+              continue;
+            }
+            // IDCT
+            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
+            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
+                              idct_stride[c], group_dec_cache->scratch_space);
+          }
+        }
+        bx += llf_x;
+      }
+    }
+  }
+  if (draw == kDontDraw) {
+    return true;
+  }
+  return true;
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+namespace {
+// Decode quantized AC coefficients of DCT blocks.
+// LLF components in the output block will not be modified.
+template <ACType ac_type>
+Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks,
+                        int32_t* JXL_RESTRICT row_nzeros,
+                        const int32_t* JXL_RESTRICT row_nzeros_top,
+                        size_t nzeros_stride, size_t c, size_t bx, size_t by,
+                        size_t lbx, AcStrategy acs,
+                        const coeff_order_t* JXL_RESTRICT coeff_order,
+                        BitReader* JXL_RESTRICT br,
+                        ANSSymbolReader* JXL_RESTRICT decoder,
+                        const std::vector<uint8_t>& context_map,
+                        const uint8_t* qdc_row, const int32_t* qf_row,
+                        const BlockCtxMap& block_ctx_map, ACPtr block,
+                        size_t shift = 0) {
+  PROFILER_FUNC;
+  // Equal to number of LLF coefficients.
+  const size_t covered_blocks = 1 << log2_covered_blocks;
+  const size_t size = covered_blocks * kDCTBlockSize;
+  int32_t predicted_nzeros =
+      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
+
+  size_t ord = kStrategyOrder[acs.RawStrategy()];
+  const coeff_order_t* JXL_RESTRICT order =
+      &coeff_order[CoeffOrderOffset(ord, c)];
+
+  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
+  const int32_t nzero_ctx =
+      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
+
+  size_t nzeros = decoder->ReadHybridUint(nzero_ctx, br, context_map);
+  if (nzeros + covered_blocks > size) {
+    return JXL_FAILURE("Invalid AC: nzeros too large");
+  }
+  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
+    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
+      row_nzeros[bx + x + y * nzeros_stride] =
+          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
+    }
+  }
+
+  const size_t histo_offset =
+      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
+
+  // Skip LLF
+  {
+    PROFILER_ZONE("AcDecSkipLLF, reader");
+    size_t prev = (nzeros > size / 16 ? 0 : 1);
+    for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
+      const size_t ctx =
+          histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
+                                            log2_covered_blocks, prev);
+      const size_t u_coeff = decoder->ReadHybridUint(ctx, br, context_map);
+      // Hand-rolled version of UnpackSigned, shifting before the conversion
+      // to signed integer to avoid undefined behavior of shifting negative
+      // numbers.
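+      // The mapping is u_coeff 0, 1, 2, 3, 4, ... -> 0, -1, +1, -2, +2, ...:
+      // the magnitude is u_coeff / 2, and for odd u_coeff the XOR with an
+      // all-ones mask negates it (two's complement).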
+      const size_t magnitude = u_coeff >> 1;
+      const size_t neg_sign = (~u_coeff) & 1;
+      const intptr_t coeff =
+          static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
+      if (ac_type == ACType::k16) {
+        block.ptr16[order[k]] += coeff;
+      } else {
+        block.ptr32[order[k]] += coeff;
+      }
+      prev = static_cast<size_t>(u_coeff != 0);
+      nzeros -= prev;
+    }
+    if (JXL_UNLIKELY(nzeros != 0)) {
+      return JXL_FAILURE("Invalid AC: nzeros not 0. Block (%" PRIuS ", %" PRIuS
+                         "), channel %" PRIuS,
+                         bx, by, c);
+    }
+  }
+  return true;
+}
+
+// Structs used by DecodeGroupImpl to get a quantized block.
+// GetBlockFromBitstream uses ANS decoding (and thus keeps track of row
+// pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient
+// image provided by the encoder.
+
+struct GetBlockFromBitstream : public GetBlock {
+  void StartRow(size_t by) override {
+    qf_row = rect.ConstRow(*qf, by);
+    for (size_t c = 0; c < 3; c++) {
+      size_t sby = by >> vshift[c];
+      quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0();
+      for (size_t i = 0; i < num_passes; i++) {
+        row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby);
+        row_nzeros_top[i][c] =
+            sby == 0
+                ? nullptr
+                : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1);
+      }
+    }
+  }
+
+  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
+                   size_t log2_covered_blocks, ACPtr block[3],
+                   ACType ac_type) override {
+    auto decode_ac_varblock = ac_type == ACType::k16
+                                  ? DecodeACVarBlock<ACType::k16>
+                                  : DecodeACVarBlock<ACType::k32>;
+    for (size_t c : {1, 0, 2}) {
+      size_t sbx = bx >> hshift[c];
+      size_t sby = by >> vshift[c];
+      if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) {
+        continue;
+      }
+
+      for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) {
+        JXL_RETURN_IF_ERROR(decode_ac_varblock(
+            ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c],
+            row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs,
+            &coeff_orders[pass * coeff_order_size], readers[pass],
+            &decoders[pass], context_map[pass], quant_dc_row, qf_row,
+            *block_ctx_map, block[c], shift_for_pass[pass]));
+      }
+    }
+    return true;
+  }
+
+  Status Init(BitReader* JXL_RESTRICT* JXL_RESTRICT readers, size_t num_passes,
+              size_t group_idx, size_t histo_selector_bits, const Rect& rect,
+              GroupDecCache* JXL_RESTRICT group_dec_cache,
+              PassesDecoderState* dec_state, size_t first_pass) {
+    for (size_t i = 0; i < 3; i++) {
+      hshift[i] = dec_state->shared->frame_header.chroma_subsampling.HShift(i);
+      vshift[i] = dec_state->shared->frame_header.chroma_subsampling.VShift(i);
+    }
+    this->coeff_order_size = dec_state->shared->coeff_order_size;
+    this->coeff_orders =
+        dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size;
+    this->context_map = dec_state->context_map.data() + first_pass;
+    this->readers = readers;
+    this->num_passes = num_passes;
+    this->shift_for_pass =
+        dec_state->shared->frame_header.passes.shift + first_pass;
+    this->group_dec_cache = group_dec_cache;
+    this->rect = rect;
+    block_ctx_map = &dec_state->shared->block_ctx_map;
+    qf = &dec_state->shared->raw_quant_field;
+    quant_dc = &dec_state->shared->quant_dc;
+
+    for (size_t pass = 0; pass < num_passes; pass++) {
+      // Select which histogram set to use among those of the current pass.
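+      // A group may pick one of `num_histograms` histogram sets; the chosen
+      // index offsets every AC context id via ctx_offset below.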
+      size_t cur_histogram = 0;
+      if (histo_selector_bits != 0) {
+        cur_histogram = readers[pass]->ReadBits(histo_selector_bits);
+      }
+      if (cur_histogram >= dec_state->shared->num_histograms) {
+        return JXL_FAILURE("Invalid histogram selector");
+      }
+      ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts();
+
+      decoders[pass] =
+          ANSSymbolReader(&dec_state->code[pass + first_pass], readers[pass]);
+    }
+    nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow();
+    for (size_t i = 0; i < num_passes; i++) {
+      JXL_ASSERT(
+          nzeros_stride ==
+          static_cast<size_t>(group_dec_cache->num_nzeroes[i].PixelsPerRow()));
+    }
+    return true;
+  }
+
+  const uint32_t* shift_for_pass = nullptr;  // not owned
+  const coeff_order_t* JXL_RESTRICT coeff_orders;
+  size_t coeff_order_size;
+  const std::vector<uint8_t>* JXL_RESTRICT context_map;
+  ANSSymbolReader decoders[kMaxNumPasses];
+  BitReader* JXL_RESTRICT* JXL_RESTRICT readers;
+  size_t num_passes;
+  size_t ctx_offset[kMaxNumPasses];
+  size_t nzeros_stride;
+  int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3];
+  const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3];
+  GroupDecCache* JXL_RESTRICT group_dec_cache;
+  const BlockCtxMap* block_ctx_map;
+  const ImageI* qf;
+  const ImageB* quant_dc;
+  const int32_t* qf_row;
+  const uint8_t* quant_dc_row;
+  Rect rect;
+  size_t hshift[3], vshift[3];
+};
+
+struct GetBlockFromEncoder : public GetBlock {
+  void StartRow(size_t by) override {}
+
+  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
+                   size_t log2_covered_blocks, ACPtr block[3],
+                   ACType ac_type) override {
+    JXL_DASSERT(ac_type == ACType::k32);
+    for (size_t c = 0; c < 3; c++) {
+      // for each pass
+      for (size_t i = 0; i < quantized_ac->size(); i++) {
+        for (size_t k = 0; k < size; k++) {
+          // TODO(veluca): SIMD.
+          block[c].ptr32[k] +=
+              rows[i][c][offset + k] * (1 << shift_for_pass[i]);
+        }
+      }
+    }
+    offset += size;
+    return true;
+  }
+
+  GetBlockFromEncoder(const std::vector<std::unique_ptr<ACImage>>& ac,
+                      size_t group_idx, const uint32_t* shift_for_pass)
+      : quantized_ac(&ac), shift_for_pass(shift_for_pass) {
+    // TODO(veluca): not supported with chroma subsampling.
+    for (size_t i = 0; i < quantized_ac->size(); i++) {
+      JXL_CHECK((*quantized_ac)[i]->Type() == ACType::k32);
+      for (size_t c = 0; c < 3; c++) {
+        rows[i][c] = (*quantized_ac)[i]->PlaneRow(c, group_idx, 0).ptr32;
+      }
+    }
+  }
+
+  const std::vector<std::unique_ptr<ACImage>>* JXL_RESTRICT quantized_ac;
+  size_t offset = 0;
+  const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3];
+  const uint32_t* shift_for_pass = nullptr;  // not owned
+};
+
+HWY_EXPORT(DecodeGroupImpl);
+
+}  // namespace
+
+Status DecodeGroup(BitReader* JXL_RESTRICT* JXL_RESTRICT readers,
+                   size_t num_passes, size_t group_idx,
+                   PassesDecoderState* JXL_RESTRICT dec_state,
+                   GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread,
+                   RenderPipelineInput& render_pipeline_input,
+                   ImageBundle* JXL_RESTRICT decoded, size_t first_pass,
+                   bool force_draw, bool dc_only, bool* should_run_pipeline) {
+  PROFILER_FUNC;
+
+  DrawMode draw = (num_passes + first_pass ==
+                   dec_state->shared->frame_header.passes.num_passes) ||
+                          force_draw
+                      ? kDraw
+                      : kDontDraw;
+
+  if (should_run_pipeline) {
+    *should_run_pipeline = draw != kDontDraw;
+  }
+
+  if (draw == kDraw && num_passes == 0 && first_pass == 0) {
+    group_dec_cache->InitDCBufferOnce();
+    const YCbCrChromaSubsampling& cs =
+        dec_state->shared->frame_header.chroma_subsampling;
+    for (size_t c : {0, 1, 2}) {
+      size_t hs = cs.HShift(c);
+      size_t vs = cs.VShift(c);
+      // We reuse filter_input_storage here as it is not currently in use.
+      const Rect src_rect_precs = dec_state->shared->BlockGroupRect(group_idx);
+      const Rect src_rect =
+          Rect(src_rect_precs.x0() >> hs, src_rect_precs.y0() >> vs,
+               src_rect_precs.xsize() >> hs, src_rect_precs.ysize() >> vs);
+      const Rect copy_rect(kRenderPipelineXOffset, 2, src_rect.xsize(),
+                           src_rect.ysize());
+      CopyImageToWithPadding(src_rect, dec_state->shared->dc->Plane(c), 2,
+                             copy_rect, &group_dec_cache->dc_buffer);
+      // Mirrorpad. Interleaving left and right padding ensures that padding
+      // works out correctly even for images with DC size of 1.
+      for (size_t y = 0; y < src_rect.ysize() + 4; y++) {
+        size_t xend = kRenderPipelineXOffset +
+                      (dec_state->shared->dc->Plane(c).xsize() >> hs) -
+                      src_rect.x0();
+        for (size_t ix = 0; ix < 2; ix++) {
+          if (src_rect.x0() == 0) {
+            group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset - ix - 1] =
+                group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset + ix];
+          }
+          if (src_rect.x0() + src_rect.xsize() + 2 >=
+              (dec_state->shared->dc->xsize() >> hs)) {
+            group_dec_cache->dc_buffer.Row(y)[xend + ix] =
+                group_dec_cache->dc_buffer.Row(y)[xend - ix - 1];
+          }
+        }
+      }
+      Rect dst_rect = render_pipeline_input.GetBuffer(c).second;
+      ImageF* upsampling_dst = render_pipeline_input.GetBuffer(c).first;
+      JXL_ASSERT(dst_rect.IsInside(*upsampling_dst));
+
+      RenderPipelineStage::RowInfo input_rows(1, std::vector<float*>(5));
+      RenderPipelineStage::RowInfo output_rows(1, std::vector<float*>(8));
+      for (size_t y = src_rect.y0(); y < src_rect.y0() + src_rect.ysize();
+           y++) {
+        for (ssize_t iy = 0; iy < 5; iy++) {
+          input_rows[0][iy] = group_dec_cache->dc_buffer.Row(
+              Mirror(ssize_t(y) + iy - 2,
+                     dec_state->shared->dc->Plane(c).ysize() >> vs) +
+              2 - src_rect.y0());
+        }
+        for (size_t iy = 0; iy < 8; iy++) {
+          output_rows[0][iy] =
+              dst_rect.Row(upsampling_dst, ((y - src_rect.y0()) << 3) + iy) -
+              kRenderPipelineXOffset;
+        }
+        // Arguments set to 0/nullptr are not used.
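+        // Each DC row expands into 8 output rows: the 8x upsampler reads a
+        // window of 5 mirrored DC rows centered on `y` and writes rows
+        // (y - y0) * 8 .. (y - y0) * 8 + 7 of the destination buffer.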
+        dec_state->upsampler8x->ProcessRow(input_rows, output_rows,
+                                           /*xextra=*/0, src_rect.xsize(), 0, 0,
+                                           thread);
+      }
+    }
+    return true;
+  }
+
+  size_t histo_selector_bits = 0;
+  if (dc_only) {
+    JXL_ASSERT(num_passes == 0);
+  } else {
+    JXL_ASSERT(dec_state->shared->num_histograms > 0);
+    histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms);
+  }
+
+  GetBlockFromBitstream get_block;
+  JXL_RETURN_IF_ERROR(
+      get_block.Init(readers, num_passes, group_idx, histo_selector_bits,
+                     dec_state->shared->BlockGroupRect(group_idx),
+                     group_dec_cache, dec_state, first_pass));
+
+  JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
+      &get_block, group_dec_cache, dec_state, thread, group_idx,
+      render_pipeline_input, decoded, draw));
+
+  for (size_t pass = 0; pass < num_passes; pass++) {
+    if (!get_block.decoders[pass].CheckANSFinalState()) {
+      return JXL_FAILURE("ANS checksum failure.");
+    }
+  }
+  return true;
+}
+
+Status DecodeGroupForRoundtrip(const std::vector<std::unique_ptr<ACImage>>& ac,
+                               size_t group_idx,
+                               PassesDecoderState* JXL_RESTRICT dec_state,
+                               GroupDecCache* JXL_RESTRICT group_dec_cache,
+                               size_t thread,
+                               RenderPipelineInput& render_pipeline_input,
+                               ImageBundle* JXL_RESTRICT decoded,
+                               AuxOut* aux_out) {
+  PROFILER_FUNC;
+
+  GetBlockFromEncoder get_block(ac, group_idx,
+                                dec_state->shared->frame_header.passes.shift);
+  group_dec_cache->InitOnce(
+      /*num_passes=*/0,
+      /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1);
+
+  return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
+      &get_block, group_dec_cache, dec_state, thread, group_idx,
+      render_pipeline_input, decoded, kDraw);
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE