diff options
Diffstat (limited to 'third_party/rust/encoding_rs/src/single_byte.rs')
-rw-r--r-- | third_party/rust/encoding_rs/src/single_byte.rs | 64 |
1 files changed, 63 insertions, 1 deletions
diff --git a/third_party/rust/encoding_rs/src/single_byte.rs b/third_party/rust/encoding_rs/src/single_byte.rs index b3b6089d31..b7a4bf23da 100644 --- a/third_party/rust/encoding_rs/src/single_byte.rs +++ b/third_party/rust/encoding_rs/src/single_byte.rs @@ -53,6 +53,9 @@ impl SingleByteDecoder { // statically omit the bound check when accessing // `[u16; 128]` with an index // `non_ascii as usize - 0x80usize`. + // + // Safety: `non_ascii` is a u8 byte >=0x80, from the invariants + // on Utf8Destination::copy_ascii_from_check_space_bmp() let mapped = unsafe { *(self.table.get_unchecked(non_ascii as usize - 0x80usize)) }; // let mapped = self.table[non_ascii as usize - 0x80usize]; @@ -151,9 +154,12 @@ impl SingleByteDecoder { } else { (DecoderResult::InputEmpty, src.len()) }; + // Safety invariant: converted <= length. Quite often we have `converted < length` + // which will be separately marked. let mut converted = 0usize; 'outermost: loop { match unsafe { + // Safety: length is the minimum length, `src/dst + x` will always be valid for reads/writes of `len - x` ascii_to_basic_latin( src.as_ptr().add(converted), dst.as_mut_ptr().add(converted), @@ -164,6 +170,12 @@ impl SingleByteDecoder { return (pending, length, length); } Some((mut non_ascii, consumed)) => { + // Safety invariant: `converted <= length` upheld, since this can only consume + // up to `length - converted` bytes. + // + // Furthermore, in this context, + // we can assume `converted < length` since this branch is only ever hit when + // ascii_to_basic_latin fails to consume the entire slice converted += consumed; 'middle: loop { // `converted` doesn't count the reading of `non_ascii` yet. @@ -172,6 +184,9 @@ impl SingleByteDecoder { // statically omit the bound check when accessing // `[u16; 128]` with an index // `non_ascii as usize - 0x80usize`. 
+ // + // Safety: We can rely on `non_ascii` being between `0x80` and `0xFF` due to + // the invariants of `ascii_to_basic_latin()`, and our table has enough space for that. let mapped = unsafe { *(self.table.get_unchecked(non_ascii as usize - 0x80usize)) }; // let mapped = self.table[non_ascii as usize - 0x80usize]; @@ -183,9 +198,10 @@ impl SingleByteDecoder { ); } unsafe { - // The bound check has already been performed + // Safety: As mentioned above, `converted < length` *(dst.get_unchecked_mut(converted)) = mapped; } + // Safety: `converted <= length` upheld, since `converted < length` before this converted += 1; // Next, handle ASCII punctuation and non-ASCII without // going back to ASCII acceleration. Non-ASCII scripts @@ -198,7 +214,10 @@ impl SingleByteDecoder { if converted == length { return (pending, length, length); } + // Safety: We are back to `converted < length` because of the == above + // and can perform this check. let mut b = unsafe { *(src.get_unchecked(converted)) }; + // Safety: `converted < length` is upheld for this loop 'innermost: loop { if b > 127 { non_ascii = b; @@ -208,15 +227,20 @@ impl SingleByteDecoder { // byte unconditionally instead of trying to unread it // to make it part of the next SIMD stride. unsafe { + // Safety: `converted < length` is true for this loop *(dst.get_unchecked_mut(converted)) = u16::from(b); } + // Safety: We are now at `converted <= length`. 
We should *not* `continue` + // the loop without reverifying converted += 1; if b < 60 { // We've got punctuation if converted == length { return (pending, length, length); } + // Safety: we're back to `converted < length` because of the == above b = unsafe { *(src.get_unchecked(converted)) }; + // Safety: The loop continues as `converted < length` continue 'innermost; } // We've got markup or ASCII text @@ -234,6 +258,8 @@ impl SingleByteDecoder { loop { if let Some((non_ascii, offset)) = validate_ascii(bytes) { total += offset; + // Safety: We can rely on `non_ascii` being between `0x80` and `0xFF` due to + // the invariants of `validate_ascii()`, and our table has enough space for that. let mapped = unsafe { *(self.table.get_unchecked(non_ascii as usize - 0x80usize)) }; if mapped != u16::from(non_ascii) { return total; @@ -384,9 +410,12 @@ impl SingleByteEncoder { } else { (EncoderResult::InputEmpty, src.len()) }; + // Safety invariant: converted <= length. Quite often we have `converted < length` + // which will be separately marked. let mut converted = 0usize; 'outermost: loop { match unsafe { + // Safety: length is the minimum length, `src/dst + x` will always be valid for reads/writes of `len - x` basic_latin_to_ascii( src.as_ptr().add(converted), dst.as_mut_ptr().add(converted), @@ -397,15 +426,23 @@ impl SingleByteEncoder { return (pending, length, length); } Some((mut non_ascii, consumed)) => { + // Safety invariant: `converted <= length` upheld, since this can only consume + // up to `length - converted` bytes. + // + // Furthermore, in this context, + // we can assume `converted < length` since this branch is only ever hit when + // basic_latin_to_ascii fails to consume the entire slice converted += consumed; 'middle: loop { // `converted` doesn't count the reading of `non_ascii` yet.
match self.encode_u16(non_ascii) { Some(byte) => { unsafe { + // Safety: we're allowed this access since `converted < length` *(dst.get_unchecked_mut(converted)) = byte; } converted += 1; + // `converted <= length` now } None => { // At this point, we need to know if we @@ -421,6 +458,8 @@ impl SingleByteEncoder { converted, ); } + // Safety: `converted < length` from outside the match, and `converted + 1 != length`, + // so `converted + 1 < length` as well. We're in bounds let second = u32::from(unsafe { *src.get_unchecked(converted + 1) }); if second & 0xFC00u32 != 0xDC00u32 { @@ -432,6 +471,18 @@ impl SingleByteEncoder { } // The next code unit is a low surrogate. let astral: char = unsafe { + // Safety: We can rely on non_ascii being 0xD800-0xDBFF since the high bits are 0xD800 + // Then, ((non_ascii << 10) - (0xD800 << 10)) becomes between (0 to 0x3FF) << 10, which is between + // 0 to 0xffc00. Adding the 0x10000 gives a range of 0x10000 to 0x10fc00. Subtracting the 0xDC00 + // gives 0x2400 to 0x102000 + // The second term is between 0xDC00 and 0xDFFF from the check above. This gives a maximum + // possible range of (0x2400 + 0xDC00) to (0x102000 + 0xDFFF) which is 0x10000 to 0x10ffff. + // This is in range.
+ // + // From a Unicode principles perspective this can also be verified as we have checked that `non_ascii` is a high surrogate + // (0xD800..=0xDBFF), and that `second` is a low surrogate (`0xDC00..=0xDFFF`), and we are applying the reverse of the UTF-16 transformation + // algorithm <https://en.wikipedia.org/wiki/UTF-16#Code_points_from_U+010000_to_U+10FFFF>, by applying the high surrogate - 0xD800 to the + // high ten bits, and the low surrogate - 0xDC00 to the low ten bits, and then adding 0x10000 ::core::char::from_u32_unchecked( (u32::from(non_ascii) << 10) + second - (((0xD800u32 << 10) - 0x1_0000u32) + 0xDC00u32), @@ -456,6 +507,7 @@ impl SingleByteEncoder { converted + 1, // +1 `for non_ascii` converted, ); + // Safety: This branch diverges, so no need to uphold invariants on `converted` } } // Next, handle ASCII punctuation and non-ASCII without @@ -469,8 +521,12 @@ impl SingleByteEncoder { if converted == length { return (pending, length, length); } + // Safety: we're back to `converted < length` due to the == above and can perform + // the unchecked read let mut unit = unsafe { *(src.get_unchecked(converted)) }; 'innermost: loop { + // Safety: This loop always begins with `converted < length`, see + // the invariant outside and the comment on the continue below if unit > 127 { non_ascii = unit; continue 'middle; @@ -479,19 +535,25 @@ impl SingleByteEncoder { // byte unconditionally instead of trying to unread it // to make it part of the next SIMD stride. unsafe { + // Safety: Can rely on converted < length *(dst.get_unchecked_mut(converted)) = unit as u8; } converted += 1; + // `converted <= length` here if unit < 60 { // We've got punctuation if converted == length { return (pending, length, length); } + // Safety: `converted < length` due to the == above. The read is safe.
unit = unsafe { *(src.get_unchecked(converted)) }; + // Safety: This only happens if `converted < length`, maintaining it continue 'innermost; } // We've got markup or ASCII text continue 'outermost; + // Safety: All other routes to here diverge so the continue is the only + // way to run the innermost loop. } } } |