summaryrefslogtreecommitdiffstats
path: root/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc
blob: 1ab499ccf611264c660add0bfd229ffcf58d627a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/modular/transform/palette.h"

namespace jxl {

// Undoes the palette transform on `input`.
//
// On entry, input.channel[0] is the palette meta-channel: its height is the
// number of palette components (`nb`) and its width is the number of palette
// entries. input.channel[begin_c + 1] holds the palette indices. The index
// channel is expanded in place into `nb` channels with the same size and
// shifts. The palette channel is then removed, so the reconstructed channels
// end up at begin_c .. begin_c + nb - 1. input.nb_meta_channels is adjusted
// to match.
//
// When nb_deltas == 0 and predictor == Predictor::Zero, every index is
// replaced by its palette value. Otherwise, indices below `nb_deltas` are
// treated as deltas: the palette value is added to the prediction made by
// `predictor` from already-reconstructed neighbours. For
// Predictor::Weighted, `wp_header` configures the predictor state. Other
// indices are replaced by the palette value directly.
//
// `nb_colors` is not read by this function. Rows (or channels, in the delta
// path) are distributed over `pool`.
//
// Returns a failure status if there is no meta channel, the index channel is
// out of range, the palette height is < 1, or a channel/image allocation or
// pool run fails.
Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors,
                  uint32_t nb_deltas, Predictor predictor,
                  const weighted::Header &wp_header, ThreadPool *pool) {
  if (input.nb_meta_channels < 1) {
    return JXL_FAILURE("Error: Palette transform without palette.");
  }
  int nb = input.channel[0].h;
  // Channel 0 is the palette, so the index channel sits one slot after begin_c.
  uint32_t c0 = begin_c + 1;
  if (c0 >= input.channel.size()) {
    return JXL_FAILURE("Channel is out of range.");
  }
  size_t w = input.channel[c0].w;
  size_t h = input.channel[c0].h;
  if (nb < 1) return JXL_FAILURE("Corrupted transforms");
  // Make room for the nb - 1 extra output channels right after the index
  // channel; they share its dimensions and shifts.
  for (int i = 1; i < nb; i++) {
    StatusOr<Channel> channel_or = Channel::Create(
        w, h, input.channel[c0].hshift, input.channel[c0].vshift);
    JXL_RETURN_IF_ERROR(channel_or.status());
    input.channel.insert(input.channel.begin() + c0 + 1,
                         std::move(channel_or).value());
  }
  // Taken only after all insertions so the reference/pointers stay valid
  // until the final erase below.
  const Channel &palette = input.channel[0];
  const pixel_type *JXL_RESTRICT p_palette = input.channel[0].Row(0);
  intptr_t onerow = input.channel[0].plane.PixelsPerRow();
  intptr_t onerow_image = input.channel[c0].plane.PixelsPerRow();
  const int bit_depth = std::min(input.bitdepth, 24);

  if (w == 0) {
    // Nothing to do.
    // Avoid touching "empty" channels with non-zero height.
  } else if (nb_deltas == 0 && predictor == Predictor::Zero) {
    if (nb == 1) {
      JXL_RETURN_IF_ERROR(RunOnPool(
          pool, 0, h, ThreadPool::NoInit,
          [&](const uint32_t task, size_t /* thread */) {
            const size_t y = task;
            pixel_type *p = input.channel[c0].Row(y);
            for (size_t x = 0; x < w; x++) {
              // This path clamps the index into [0, palette.w - 1].
              // NOTE(review): the other paths pass the raw index and
              // presumably rely on GetPaletteValue for out-of-range
              // indices — confirm.
              const int index =
                  Clamp1<int>(p[x], 0, static_cast<pixel_type>(palette.w) - 1);
              p[x] = palette_internal::GetPaletteValue(
                  p_palette, index, /*c=*/0,
                  /*palette_size=*/palette.w,
                  /*onerow=*/onerow, /*bit_depth=*/bit_depth);
            }
          },
          "UndoChannelPalette"));
    } else {
      JXL_RETURN_IF_ERROR(RunOnPool(
          pool, 0, h, ThreadPool::NoInit,
          [&](const uint32_t task, size_t /* thread */) {
            const size_t y = task;
            std::vector<pixel_type *> p_out(nb);
            // p_index aliases p_out[0]: each index is read before the first
            // output component overwrites the same pixel.
            const pixel_type *p_index = input.channel[c0].Row(y);
            for (int c = 0; c < nb; c++)
              p_out[c] = input.channel[c0 + c].Row(y);
            for (size_t x = 0; x < w; x++) {
              const int index = p_index[x];
              for (int c = 0; c < nb; c++) {
                p_out[c][x] = palette_internal::GetPaletteValue(
                    p_palette, index, /*c=*/c,
                    /*palette_size=*/palette.w,
                    /*onerow=*/onerow, /*bit_depth=*/bit_depth);
              }
            }
          },
          "UndoPalette"));
    }
  } else {
    // Parallelized per channel: prediction reads neighbours from the channel
    // being reconstructed, so each channel is walked in raster order by a
    // single task.
    ImageI indices;
    ImageI &plane = input.channel[c0].plane;
    JXL_ASSIGN_OR_RETURN(indices, ImageI::Create(plane.xsize(), plane.ysize()));
    // Move the decoded indices into `indices` so channel c0 can be
    // overwritten with reconstructed values while the indices stay readable.
    plane.Swap(indices);
    if (predictor == Predictor::Weighted) {
      JXL_RETURN_IF_ERROR(RunOnPool(
          pool, 0, nb, ThreadPool::NoInit,
          [&](const uint32_t c, size_t /* thread */) {
            Channel &channel = input.channel[c0 + c];
            weighted::State wp_state(wp_header, channel.w, channel.h);
            for (size_t y = 0; y < channel.h; y++) {
              pixel_type *JXL_RESTRICT p = channel.Row(y);
              const pixel_type *JXL_RESTRICT idx = indices.Row(y);
              for (size_t x = 0; x < channel.w; x++) {
                int index = idx[x];
                pixel_type_w val = 0;
                const pixel_type palette_entry =
                    palette_internal::GetPaletteValue(
                        p_palette, index, /*c=*/c,
                        /*palette_size=*/palette.w, /*onerow=*/onerow,
                        /*bit_depth=*/bit_depth);
                if (index < static_cast<int32_t>(nb_deltas)) {
                  // Delta entry: add the palette value to the prediction.
                  PredictionResult pred =
                      PredictNoTreeWP(channel.w, p + x, onerow_image, x, y,
                                      predictor, &wp_state);
                  val = pred.guess + palette_entry;
                } else {
                  val = palette_entry;
                }
                p[x] = val;
                // The weighted predictor learns from every reconstructed
                // pixel, including non-delta ones.
                wp_state.UpdateErrors(p[x], x, y, channel.w);
              }
            }
          },
          "UndoDeltaPaletteWP"));
    } else {
      JXL_RETURN_IF_ERROR(RunOnPool(
          pool, 0, nb, ThreadPool::NoInit,
          [&](const uint32_t c, size_t /* thread */) {
            Channel &channel = input.channel[c0 + c];
            for (size_t y = 0; y < channel.h; y++) {
              pixel_type *JXL_RESTRICT p = channel.Row(y);
              const pixel_type *JXL_RESTRICT idx = indices.Row(y);
              for (size_t x = 0; x < channel.w; x++) {
                int index = idx[x];
                pixel_type_w val = 0;
                const pixel_type palette_entry =
                    palette_internal::GetPaletteValue(
                        p_palette, index, /*c=*/c,
                        /*palette_size=*/palette.w,
                        /*onerow=*/onerow, /*bit_depth=*/bit_depth);
                if (index < static_cast<int32_t>(nb_deltas)) {
                  // Delta entry: add the palette value to the prediction.
                  PredictionResult pred = PredictNoTreeNoWP(
                      channel.w, p + x, onerow_image, x, y, predictor);
                  val = pred.guess + palette_entry;
                } else {
                  val = palette_entry;
                }
                p[x] = val;
              }
            }
          },
          "UndoDeltaPaletteNoWP"));
    }
  }
  if (c0 >= input.nb_meta_channels) {
    // Palette was done on normal channels: only the palette meta-channel
    // goes away.
    input.nb_meta_channels--;
  } else {
    // Palette was done on metachannels: the palette channel is removed and
    // nb - 1 meta channels are added, a net change of nb - 2.
    JXL_ASSERT(static_cast<int>(input.nb_meta_channels) >= 2 - nb);
    input.nb_meta_channels -= 2 - nb;
    JXL_ASSERT(begin_c + nb - 1 < input.nb_meta_channels);
  }
  // Drop the palette meta-channel.
  input.channel.erase(input.channel.begin(), input.channel.begin() + 1);
  // All failures above are reported via early returns; no error state is
  // accumulated by the worker lambdas.
  return true;
}

// Applies the channel-layout change of the palette transform to `input`
// without writing any pixel data.
//
// Channels begin_c .. end_c are collapsed into the single channel at begin_c;
// channels begin_c + 1 .. end_c are erased. A new palette meta-channel is
// inserted at index 0. Its width is nb_colors + nb_deltas, its height is
// end_c - begin_c + 1, and it has hshift = vshift = -1.
// input.nb_meta_channels is updated to match.
//
// `lossy` is not read by this function.
//
// Returns a failure status, leaving `input` unmodified, when:
//   - CheckEqualChannels rejects the range;
//   - the range starts among the meta channels but ends outside them;
//   - the palette channel cannot be allocated.
Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c,
                   uint32_t nb_colors, uint32_t nb_deltas, bool lossy) {
  JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c));

  const size_t nb = end_c - begin_c + 1;
  const bool on_meta_channels = begin_c < input.nb_meta_channels;
  if (on_meta_channels && end_c >= input.nb_meta_channels) {
    // Reject instead of aborting: the range comes from transform parameters.
    return JXL_FAILURE("Invalid: palette over meta and non-meta channels");
  }

  // Allocate the palette channel before touching `input`, so an allocation
  // failure cannot leave the channel list and nb_meta_channels out of sync.
  // The size is summed in size_t to avoid uint32_t wrap-around.
  JXL_ASSIGN_OR_RETURN(
      Channel pch,
      Channel::Create(static_cast<size_t>(nb_colors) + nb_deltas, nb));
  pch.hshift = -1;
  pch.vshift = -1;

  if (on_meta_channels) {
    // Palette was done on metachannels:
    // we remove nb-1 metachannels and add one.
    input.nb_meta_channels += 2 - nb;
  } else {
    // Palette was done on normal channels: only the palette is a new
    // meta channel.
    input.nb_meta_channels++;
  }
  input.channel.erase(input.channel.begin() + begin_c + 1,
                      input.channel.begin() + end_c + 1);
  input.channel.insert(input.channel.begin(), std::move(pch));
  return true;
}

}  // namespace jxl