1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
|
/// Reads bit fields from a borrowed byte slice.
///
/// Bits are consumed least-significant-bit first within each byte, and bytes are
/// consumed in slice order. The reader only borrows its input, so cloning it is cheap
/// and can be used to snapshot (and later restore) a read position.
#[derive(Debug, Clone)]
pub struct BitReader<'s> {
    /// Number of bits already consumed; kept `<= source.len() * 8` by the reader's methods.
    idx: usize,
    /// The bytes being read.
    source: &'s [u8],
}
/// Errors returned by [`BitReader::get_bits`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum GetBitsError {
    /// A single read asked for more bits than the reader serves per call
    /// (`get_bits` rejects any request above 64 bits, the width of its `u64` result).
    #[error("Cant serve this request. The reader is limited to {limit} bits, requested {num_requested_bits} bits")]
    TooManyBits {
        /// The bit count the caller asked for.
        num_requested_bits: usize,
        /// The maximum bit count a single read may request.
        limit: u8,
    },
    /// The read would run past the end of the source slice; `remaining` is how many
    /// bits were still available when the request was rejected.
    #[error("Can't read {requested} bits, only have {remaining} bits left")]
    NotEnoughRemainingBits { requested: usize, remaining: usize },
}
impl<'s> BitReader<'s> {
pub fn new(source: &'s [u8]) -> BitReader<'_> {
BitReader { idx: 0, source }
}
pub fn bits_left(&self) -> usize {
self.source.len() * 8 - self.idx
}
pub fn bits_read(&self) -> usize {
self.idx
}
pub fn return_bits(&mut self, n: usize) {
if n > self.idx {
panic!("Cant return this many bits");
}
self.idx -= n;
}
pub fn get_bits(&mut self, n: usize) -> Result<u64, GetBitsError> {
if n > 64 {
return Err(GetBitsError::TooManyBits {
num_requested_bits: n,
limit: 64,
});
}
if self.bits_left() < n {
return Err(GetBitsError::NotEnoughRemainingBits {
requested: n,
remaining: self.bits_left(),
});
}
let old_idx = self.idx;
let bits_left_in_current_byte = 8 - (self.idx % 8);
let bits_not_needed_in_current_byte = 8 - bits_left_in_current_byte;
//collect bits from the currently pointed to byte
let mut value = u64::from(self.source[self.idx / 8] >> bits_not_needed_in_current_byte);
if bits_left_in_current_byte >= n {
//no need for fancy stuff
//just mask all but the needed n bit
value &= (1 << n) - 1;
self.idx += n;
} else {
self.idx += bits_left_in_current_byte;
//n spans over multiple bytes
let full_bytes_needed = (n - bits_left_in_current_byte) / 8;
let bits_in_last_byte_needed = n - bits_left_in_current_byte - full_bytes_needed * 8;
assert!(
bits_left_in_current_byte + full_bytes_needed * 8 + bits_in_last_byte_needed == n
);
let mut bit_shift = bits_left_in_current_byte; //this many bits are already set in value
assert!(self.idx % 8 == 0);
//collect full bytes
for _ in 0..full_bytes_needed {
value |= u64::from(self.source[self.idx / 8]) << bit_shift;
self.idx += 8;
bit_shift += 8;
}
assert!(n - bit_shift == bits_in_last_byte_needed);
if bits_in_last_byte_needed > 0 {
let val_las_byte =
u64::from(self.source[self.idx / 8]) & ((1 << bits_in_last_byte_needed) - 1);
value |= val_las_byte << bit_shift;
self.idx += bits_in_last_byte_needed;
}
}
assert!(self.idx == old_idx + n);
Ok(value)
}
pub fn reset(&mut self, new_source: &'s [u8]) {
self.idx = 0;
self.source = new_source;
}
}
|