summaryrefslogtreecommitdiffstats
path: root/third_party/jpeg-xl/lib/jxl/dec_context_map.cc
blob: baff87fa493caf28ac44d948edfb8b422caf8275 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/dec_context_map.h"

#include <algorithm>
#include <cstdint>
#include <vector>

#include "lib/jxl/ans_params.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_ans.h"
#include "lib/jxl/entropy_coder.h"
#include "lib/jxl/inverse_mtf-inl.h"

namespace jxl {

namespace {

// Checks that `context_map` is a valid mapping onto exactly `num_htrees`
// histograms: every entry must be a histogram index below `num_htrees`, and
// every index in [0, num_htrees) must be referenced by at least one entry.
// Returns a failure status describing the first violated condition.
Status VerifyContextMap(const std::vector<uint8_t>& context_map,
                        const size_t num_htrees) {
  // One flag per histogram index, set once that index has been referenced.
  // A byte per flag avoids the proxy-reference std::vector<bool>.
  std::vector<uint8_t> seen(num_htrees, 0);
  size_t distinct = 0;
  for (size_t ctx = 0; ctx < context_map.size(); ++ctx) {
    const size_t idx = context_map[ctx];
    if (idx >= num_htrees) {
      return JXL_FAILURE("Invalid histogram index in context map.");
    }
    if (seen[idx]) continue;
    seen[idx] = 1;
    ++distinct;
  }
  // Every histogram index must be used by at least one context.
  if (distinct != num_htrees) return JXL_FAILURE("Incomplete context map.");
  return true;
}

}  // namespace

// Decodes a context map from `input` into `*context_map`. The vector must
// already be sized to the number of contexts, and every entry is overwritten.
//
// The first bit selects the encoding:
//  - simple: a 2-bit width `bits_per_entry`, then one value of that width per
//    context. A width of 0 maps every context to histogram 0.
//  - entropy-coded: a 1-bit move-to-front flag, then a histogram set decoded
//    by DecodeHistograms for a single context, then one hybrid-uint symbol per
//    context. If the flag is set, an inverse move-to-front transform is
//    applied to the result.
//
// On success, `*num_htrees` is set to the largest histogram index plus one.
// The map is then verified to reference every index in [0, *num_htrees).
// Returns a failure status for an empty map, an out-of-range cluster ID, a bad
// final ANS state, or an incomplete map.
Status DecodeContextMap(std::vector<uint8_t>* context_map, size_t* num_htrees,
                        BitReader* input) {
  // *num_htrees is derived from the largest entry below. An empty map has no
  // largest entry: std::max_element would return end(), and dereferencing it
  // is undefined behavior. Reject that case before touching the bitstream.
  if (context_map->empty()) {
    return JXL_FAILURE("Empty context map");
  }
  const bool is_simple = static_cast<bool>(input->ReadFixedBits<1>());
  if (is_simple) {
    const int bits_per_entry = input->ReadFixedBits<2>();
    if (bits_per_entry != 0) {
      for (uint8_t& entry : *context_map) {
        // bits_per_entry <= 3, so the value always fits in a uint8_t.
        entry = static_cast<uint8_t>(input->ReadBits(bits_per_entry));
      }
    } else {
      std::fill(context_map->begin(), context_map->end(), 0);
    }
  } else {
    const bool use_mtf = static_cast<bool>(input->ReadFixedBits<1>());
    ANSCode code;
    std::vector<uint8_t> sink_ctx_map;
    // Usage of LZ77 is disallowed if decoding only two symbols. This doesn't
    // make sense in non-malicious bitstreams, and could cause a stack overflow
    // in malicious bitstreams by making every context map require its own
    // context map.
    JXL_RETURN_IF_ERROR(
        DecodeHistograms(input, 1, &code, &sink_ctx_map,
                         /*disallow_lz77=*/context_map->size() <= 2));
    ANSSymbolReader reader(&code, input);
    // Track the maximum in a wide type so a huge symbol cannot wrap around to
    // a small value and slip past the kMaxClusters check below.
    size_t maxsym = 0;
    for (uint8_t& entry : *context_map) {
      const size_t sym = reader.ReadHybridUintInlined</*uses_lz77=*/true>(
          0, input, sink_ctx_map);
      maxsym = std::max(maxsym, sym);
      // An out-of-range symbol is truncated here, but the whole map is
      // rejected right after the loop via the kMaxClusters check.
      entry = static_cast<uint8_t>(sym);
    }
    if (maxsym >= kMaxClusters) {
      return JXL_FAILURE("Invalid cluster ID");
    }
    if (!reader.CheckANSFinalState()) {
      return JXL_FAILURE("Invalid context map");
    }
    if (use_mtf) {
      InverseMoveToFrontTransform(context_map->data(), context_map->size());
    }
  }
  // Safe: the map is non-empty (checked at the top).
  *num_htrees = *std::max_element(context_map->begin(), context_map->end()) + 1;
  return VerifyContextMap(*context_map, *num_htrees);
}

}  // namespace jxl