// path: root/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.h
// blob: f30881b23258ee2e87de8b69092908c36511349f
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#ifndef LIB_JXL_ENC_PATCH_DICTIONARY_H_
#define LIB_JXL_ENC_PATCH_DICTIONARY_H_

// Chooses reference patches, and avoids encoding them once per occurrence.

#include <stddef.h>
#include <string.h>
#include <sys/types.h>

#include <tuple>
#include <vector>

#include "lib/jxl/base/data_parallel.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/chroma_from_luma.h"
#include "lib/jxl/common.h"
#include "lib/jxl/dec_bit_reader.h"
#include "lib/jxl/dec_patch_dictionary.h"
#include "lib/jxl/enc_bit_writer.h"
#include "lib/jxl/enc_cache.h"
#include "lib/jxl/enc_params.h"
#include "lib/jxl/image.h"
#include "lib/jxl/opsin_params.h"

namespace jxl {

struct AuxOut;

// Side length used to size the per-channel buffers of QuantizedPatch: each
// buffer holds kMaxPatchSize * kMaxPatchSize entries.
constexpr size_t kMaxPatchSize = 32;

// A three-channel patch stored twice: as quantized int8_t values (`pixels`,
// which drive equality and ordering) and as floats (`fpixels`).
//
// Both comparison operators look at xsize, ysize and then the first
// xsize * ysize entries of each quantized channel; `fpixels` is never
// compared. The channel buffers are allocated once, at construction, with
// kMaxPatchSize * kMaxPatchSize entries each.
// NOTE(review): the comparisons assume xsize * ysize does not exceed that
// buffer size; nothing in this struct enforces it.
struct QuantizedPatch {
  // Zero-initialized so that comparing patches whose size was never assigned
  // is well defined instead of reading indeterminate values.
  size_t xsize = 0;
  size_t ysize = 0;
  QuantizedPatch() {
    for (size_t i = 0; i < 3; i++) {
      pixels[i].resize(kMaxPatchSize * kMaxPatchSize);
      fpixels[i].resize(kMaxPatchSize * kMaxPatchSize);
    }
  }
  std::vector<int8_t> pixels[3] = {};
  // Not compared. Used only to retrieve original pixels to construct the
  // reference image.
  std::vector<float> fpixels[3] = {};

  // Returns true when both patches have the same dimensions and identical
  // quantized values in the first xsize * ysize entries of every channel.
  bool operator==(const QuantizedPatch& other) const {
    if (xsize != other.xsize) return false;
    if (ysize != other.ysize) return false;
    // Dimensions match here, so one byte count is valid for both operands.
    const size_t num_bytes = sizeof(int8_t) * xsize * ysize;
    for (size_t c = 0; c < 3; c++) {
      if (memcmp(pixels[c].data(), other.pixels[c].data(), num_bytes) != 0) {
        return false;
      }
    }
    return true;
  }

  // Ordering: by xsize, then ysize, then channel by channel using memcmp on
  // the first xsize * ysize entries (memcmp compares the bytes as unsigned
  // char, so this is a bytewise order rather than a signed-value order).
  bool operator<(const QuantizedPatch& other) const {
    if (xsize != other.xsize) return xsize < other.xsize;
    if (ysize != other.ysize) return ysize < other.ysize;
    const size_t num_bytes = sizeof(int8_t) * xsize * ysize;
    for (size_t c = 0; c < 3; c++) {
      const int cmp =
          memcmp(pixels[c].data(), other.pixels[c].data(), num_bytes);
      // The first channel that differs decides the order.
      if (cmp != 0) return cmp < 0;
    }
    return false;
  }
};

// Pair (patch, vector of occurrences).
// Each occurrence is itself a pair of uint32_t values.
// NOTE(review): presumably the (x, y) position of the occurrence in the
// image — confirm against the code that fills this in.
using PatchInfo =
    std::pair<QuantizedPatch, std::vector<std::pair<uint32_t, uint32_t>>>;

// Friend class of PatchDictionary.
// Collects encoder-side static helpers that operate on a PatchDictionary.
// SetPositions writes the dictionary's positions_, ref_positions_ and
// blendings_ members directly.
class PatchDictionaryEncoder {
 public:
  // Only call if HasAny().
  // Declared here; the definition is not in this header.
  // NOTE(review): presumably serializes `pdic` into `writer`, with `layer`
  // and `aux_out` used for bookkeeping — confirm against the definition.
  static void Encode(const PatchDictionary& pdic, BitWriter* writer,
                     size_t layer, AuxOut* aux_out);

  // Replaces the position, reference-position and blending lists of `pdic`
  // with the given vectors, then calls pdic->ComputePatchTree().
  // The vectors are taken by value and moved into `pdic`, so callers can
  // std::move their arguments in to avoid a copy.
  static void SetPositions(PatchDictionary* pdic,
                           std::vector<PatchPosition> positions,
                           std::vector<PatchReferencePosition> ref_positions,
                           std::vector<PatchBlending> blendings) {
    pdic->positions_ = std::move(positions);
    pdic->ref_positions_ = std::move(ref_positions);
    pdic->blendings_ = std::move(blendings);
    // Called only after all three lists have been replaced.
    pdic->ComputePatchTree();
  }

  // Declared here; the definition is not in this header.
  // NOTE(review): presumably subtracts the patches described by `pdic` from
  // `opsin` in place — confirm against the definition.
  static void SubtractFrom(const PatchDictionary& pdic, Image3F* opsin);
};

// Declaration only; the definition is not in this header.
// `opsin` is read-only input, `state` is a mutable encoder state, and
// `is_xyb` defaults to true.
// NOTE(review): presumably searches `opsin` for repeated patches and stores
// the chosen dictionary in `state` — confirm against the definition.
void FindBestPatchDictionary(const Image3F& opsin,
                             PassesEncoderState* JXL_RESTRICT state,
                             const JxlCmsInterface& cms, ThreadPool* pool,
                             AuxOut* aux_out, bool is_xyb = true);

// Declaration only; the definition is not in this header.
// `reference_frame` and `state` are passed as mutable pointers, and `cparams`
// is a non-const reference, so the callee may modify any of them.
// NOTE(review): presumably encodes and decodes `reference_frame` so the
// encoder sees the same reference the decoder will, with `subtract`
// controlling whether patches are then subtracted — confirm against the
// definition.
void RoundtripPatchFrame(Image3F* reference_frame,
                         PassesEncoderState* JXL_RESTRICT state, int idx,
                         CompressParams& cparams, const JxlCmsInterface& cms,
                         ThreadPool* pool, AuxOut* aux_out, bool subtract);

}  // namespace jxl

#endif  // LIB_JXL_ENC_PATCH_DICTIONARY_H_