1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "lib/jxl/dec_context_map.h"
#include <algorithm>
#include <vector>
#include "lib/jxl/ans_params.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_ans.h"
#include "lib/jxl/entropy_coder.h"
#include "lib/jxl/inverse_mtf-inl.h"
namespace jxl {
namespace {
// Validates a decoded context map against the histogram count derived from it.
// Fails if any entry is an index >= `num_htrees`, or if some histogram index in
// [0, num_htrees) is never referenced by any entry.
Status VerifyContextMap(const std::vector<uint8_t>& context_map,
                        const size_t num_htrees) {
  std::vector<bool> seen(num_htrees, false);
  size_t distinct = 0;
  for (size_t pos = 0; pos < context_map.size(); ++pos) {
    const size_t idx = context_map[pos];
    if (idx >= num_htrees) {
      return JXL_FAILURE("Invalid histogram index in context map.");
    }
    if (seen[idx]) continue;
    // First time this histogram is referenced: count it once.
    seen[idx] = true;
    ++distinct;
  }
  if (distinct == num_htrees) {
    return true;
  }
  return JXL_FAILURE("Incomplete context map.");
}
} // namespace
// Reads a context map from `input` into `*context_map`. The caller must size
// `*context_map` to the number of contexts beforehand; this function only
// overwrites existing entries. On success, `*num_htrees` is set to one more
// than the largest histogram index in the map, and the map is verified to
// reference every index in [0, *num_htrees).
//
// The first bit selects the encoding:
//  - simple: a 2-bit entry width followed by one fixed-width index per context
//    (a width of 0 maps every context to histogram 0);
//  - entropy-coded: a 1-bit move-to-front flag, a single histogram, then one
//    hybrid-uint symbol per context, optionally post-processed by an inverse
//    move-to-front transform.
Status DecodeContextMap(std::vector<uint8_t>* context_map, size_t* num_htrees,
                        BitReader* input) {
  // An empty map has no largest element. Dereferencing max_element's result
  // below would be undefined behavior, so reject it before reading any bits.
  if (context_map->empty()) {
    return JXL_FAILURE("Empty context map");
  }
  const bool is_simple = input->ReadFixedBits<1>();
  if (is_simple) {
    const int bits_per_entry = input->ReadFixedBits<2>();
    if (bits_per_entry != 0) {
      for (uint8_t& entry : *context_map) {
        // bits_per_entry <= 3, so the value always fits in uint8_t.
        entry = static_cast<uint8_t>(input->ReadBits(bits_per_entry));
      }
    } else {
      std::fill(context_map->begin(), context_map->end(), 0);
    }
  } else {
    const bool use_mtf = input->ReadFixedBits<1>();
    ANSCode code;
    std::vector<uint8_t> dummy_ctx_map;
    // Usage of LZ77 is disallowed if decoding only two symbols. This doesn't
    // make sense in non-malicious bitstreams, and could cause a stack overflow
    // in malicious bitstreams by making every context map require its own
    // context map.
    JXL_RETURN_IF_ERROR(
        DecodeHistograms(input, 1, &code, &dummy_ctx_map,
                         /*disallow_lz77=*/context_map->size() <= 2));
    ANSSymbolReader reader(&code, input);
    for (uint8_t& entry : *context_map) {
      const uint32_t sym = reader.ReadHybridUint(0, input, dummy_ctx_map);
      if (sym >= kMaxClusters) {
        return JXL_FAILURE("Invalid cluster ID");
      }
      entry = static_cast<uint8_t>(sym);
    }
    if (!reader.CheckANSFinalState()) {
      return JXL_FAILURE("Invalid context map");
    }
    if (use_mtf) {
      InverseMoveToFrontTransform(context_map->data(), context_map->size());
    }
  }
  // Safe: the map was checked to be non-empty at the top.
  *num_htrees = *std::max_element(context_map->begin(), context_map->end()) + 1;
  return VerifyContextMap(*context_map, *num_htrees);
}
} // namespace jxl
|