diff options
Diffstat (limited to 'src/arrow/cpp/src/arrow/util/utf8.h')
-rw-r--r-- | src/arrow/cpp/src/arrow/util/utf8.h | 566 |
1 files changed, 566 insertions, 0 deletions
diff --git a/src/arrow/cpp/src/arrow/util/utf8.h b/src/arrow/cpp/src/arrow/util/utf8.h new file mode 100644 index 000000000..45cdcd833 --- /dev/null +++ b/src/arrow/cpp/src/arrow/util/utf8.h @@ -0,0 +1,566 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cassert> +#include <cstdint> +#include <cstring> +#include <memory> +#include <string> + +#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2) +#include <xsimd/xsimd.hpp> +#endif + +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/simd.h" +#include "arrow/util/string_view.h" +#include "arrow/util/ubsan.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace util { + +// Convert a UTF8 string to a wstring (either UTF16 or UTF32, depending +// on the wchar_t width). +ARROW_EXPORT Result<std::wstring> UTF8ToWideString(const std::string& source); + +// Similarly, convert a wstring to a UTF8 string. +ARROW_EXPORT Result<std::string> WideStringToUTF8(const std::wstring& source); + +namespace internal { + +// Copyright (c) 2008-2010 Bjoern Hoehrmann <bjoern@hoehrmann.de> +// See http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ for details. + +// A compact state table allowing UTF8 decoding using two dependent +// lookups per byte. The first lookup determines the character class +// and the second lookup reads the next state. +// In this table states are multiples of 12. +ARROW_EXPORT extern const uint8_t utf8_small_table[256 + 9 * 12]; + +// Success / reject states when looked up in the small table +static constexpr uint8_t kUTF8DecodeAccept = 0; +static constexpr uint8_t kUTF8DecodeReject = 12; + +// An expanded state table allowing transitions using a single lookup +// at the expense of a larger memory footprint (but on non-random data, +// not all the table will end up accessed and cached). +// In this table states are multiples of 256. +ARROW_EXPORT extern uint16_t utf8_large_table[9 * 256]; + +ARROW_EXPORT extern const uint8_t utf8_byte_size_table[16]; + +// Success / reject states when looked up in the large table +static constexpr uint16_t kUTF8ValidateAccept = 0; +static constexpr uint16_t kUTF8ValidateReject = 256; + +static inline uint8_t DecodeOneUTF8Byte(uint8_t byte, uint8_t state, uint32_t* codep) { + uint8_t type = utf8_small_table[byte]; + + *codep = (state != kUTF8DecodeAccept) ? (byte & 0x3fu) | (*codep << 6) + : (0xff >> type) & (byte); + + state = utf8_small_table[256 + state + type]; + return state; +} + +static inline uint16_t ValidateOneUTF8Byte(uint8_t byte, uint16_t state) { + return utf8_large_table[state + byte]; +} + +ARROW_EXPORT void CheckUTF8Initialized(); + +} // namespace internal + +// This function needs to be called before doing UTF8 validation. +ARROW_EXPORT void InitializeUTF8(); + +static inline bool ValidateUTF8(const uint8_t* data, int64_t size) { + static constexpr uint64_t high_bits_64 = 0x8080808080808080ULL; + static constexpr uint32_t high_bits_32 = 0x80808080UL; + static constexpr uint16_t high_bits_16 = 0x8080U; + static constexpr uint8_t high_bits_8 = 0x80U; + +#ifndef NDEBUG + internal::CheckUTF8Initialized(); +#endif + + while (size >= 8) { + // XXX This is doing an unaligned access. Contemporary architectures + // (x86-64, AArch64, PPC64) support it natively and often have good + // performance nevertheless. + uint64_t mask64 = SafeLoadAs<uint64_t>(data); + if (ARROW_PREDICT_TRUE((mask64 & high_bits_64) == 0)) { + // 8 bytes of pure ASCII, move forward + size -= 8; + data += 8; + continue; + } + // Non-ASCII run detected. + // We process at least 4 bytes, to avoid too many spurious 64-bit reads + // in case the non-ASCII bytes are at the end of the tested 64-bit word. + // We also only check for rejection at the end since that state is stable + // (once in reject state, we always remain in reject state). + // It is guaranteed that size >= 8 when arriving here, which allows + // us to avoid size checks. + uint16_t state = internal::kUTF8ValidateAccept; + // Byte 0 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + // Byte 1 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + // Byte 2 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + // Byte 3 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + // Byte 4 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + if (state == internal::kUTF8ValidateAccept) { + continue; // Got full char, switch back to ASCII detection + } + // Byte 5 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + if (state == internal::kUTF8ValidateAccept) { + continue; // Got full char, switch back to ASCII detection + } + // Byte 6 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + if (state == internal::kUTF8ValidateAccept) { + continue; // Got full char, switch back to ASCII detection + } + // Byte 7 + state = internal::ValidateOneUTF8Byte(*data++, state); + --size; + if (state == internal::kUTF8ValidateAccept) { + continue; // Got full char, switch back to ASCII detection + } + // kUTF8ValidateAccept not reached along 4 transitions has to mean a rejection + assert(state == internal::kUTF8ValidateReject); + return false; + } + + // Check if string tail is full ASCII (common case, fast) + if (size >= 4) { + uint32_t tail_mask = SafeLoadAs<uint32_t>(data + size - 4); + uint32_t head_mask = SafeLoadAs<uint32_t>(data); + if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_32) == 0)) { + return true; + } + } else if (size >= 2) { + uint16_t tail_mask = SafeLoadAs<uint16_t>(data + size - 2); + uint16_t head_mask = SafeLoadAs<uint16_t>(data); + if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_16) == 0)) { + return true; + } + } else if (size == 1) { + if (ARROW_PREDICT_TRUE((*data & high_bits_8) == 0)) { + return true; + } + } else { + /* size == 0 */ + return true; + } + + // Fall back to UTF8 validation of tail string. + // Note the state table is designed so that, once in the reject state, + // we remain in that state until the end. So we needn't check for + // rejection at each char (we don't gain much by short-circuiting here). + uint16_t state = internal::kUTF8ValidateAccept; + switch (size) { + case 7: + state = internal::ValidateOneUTF8Byte(data[size - 7], state); + case 6: + state = internal::ValidateOneUTF8Byte(data[size - 6], state); + case 5: + state = internal::ValidateOneUTF8Byte(data[size - 5], state); + case 4: + state = internal::ValidateOneUTF8Byte(data[size - 4], state); + case 3: + state = internal::ValidateOneUTF8Byte(data[size - 3], state); + case 2: + state = internal::ValidateOneUTF8Byte(data[size - 2], state); + case 1: + state = internal::ValidateOneUTF8Byte(data[size - 1], state); + default: + break; + } + return ARROW_PREDICT_TRUE(state == internal::kUTF8ValidateAccept); +} + +static inline bool ValidateUTF8(const util::string_view& str) { + const uint8_t* data = reinterpret_cast<const uint8_t*>(str.data()); + const size_t length = str.size(); + + return ValidateUTF8(data, length); +} + +static inline bool ValidateAsciiSw(const uint8_t* data, int64_t len) { + uint8_t orall = 0; + + if (len >= 8) { + uint64_t or8 = 0; + + do { + or8 |= SafeLoadAs<uint64_t>(data); + data += 8; + len -= 8; + } while (len >= 8); + + orall = !(or8 & 0x8080808080808080ULL) - 1; + } + + while (len--) { + orall |= *data++; + } + + return orall < 0x80U; +} + +#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2) +static inline bool ValidateAsciiSimd(const uint8_t* data, int64_t len) { +#ifdef ARROW_HAVE_NEON + using simd_batch = xsimd::batch<int8_t, xsimd::neon64>; +#else + using simd_batch = xsimd::batch<int8_t, xsimd::sse4_2>; +#endif + + if (len >= 32) { + const simd_batch zero(static_cast<int8_t>(0)); + const uint8_t* data2 = data + 16; + simd_batch or1 = zero, or2 = zero; + + while (len >= 32) { + or1 |= simd_batch::load_unaligned(reinterpret_cast<const int8_t*>(data)); + or2 |= simd_batch::load_unaligned(reinterpret_cast<const int8_t*>(data2)); + data += 32; + data2 += 32; + len -= 32; + } + + // To test for upper bit in all bytes, test whether any of them is negative + or1 |= or2; + if (xsimd::any(or1 < zero)) { + return false; + } + } + + return ValidateAsciiSw(data, len); +} +#endif // ARROW_HAVE_SSE4_2 + +static inline bool ValidateAscii(const uint8_t* data, int64_t len) { +#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2) + return ValidateAsciiSimd(data, len); +#else + return ValidateAsciiSw(data, len); +#endif +} + +static inline bool ValidateAscii(const util::string_view& str) { + const uint8_t* data = reinterpret_cast<const uint8_t*>(str.data()); + const size_t length = str.size(); + + return ValidateAscii(data, length); +} + +// Skip UTF8 byte order mark, if any. +ARROW_EXPORT +Result<const uint8_t*> SkipUTF8BOM(const uint8_t* data, int64_t size); + +static constexpr uint32_t kMaxUnicodeCodepoint = 0x110000; + +// size of a valid UTF8 can be determined by looking at leading 4 bits of BYTE1 +// utf8_byte_size_table[0..7] --> pure ascii chars --> 1B length +// utf8_byte_size_table[8..11] --> internal bytes --> 1B length +// utf8_byte_size_table[12,13] --> 2B long UTF8 chars +// utf8_byte_size_table[14] --> 3B long UTF8 chars +// utf8_byte_size_table[15] --> 4B long UTF8 chars +// NOTE: Results for invalid/ malformed utf-8 sequences are undefined. +// ex: \xFF... returns 4B +static inline uint8_t ValidUtf8CodepointByteSize(const uint8_t* codeunit) { + return internal::utf8_byte_size_table[*codeunit >> 4]; +} + +static inline bool Utf8IsContinuation(const uint8_t codeunit) { + return (codeunit & 0xC0) == 0x80; // upper two bits should be 10 +} + +static inline bool Utf8Is2ByteStart(const uint8_t codeunit) { + return (codeunit & 0xE0) == 0xC0; // upper three bits should be 110 +} + +static inline bool Utf8Is3ByteStart(const uint8_t codeunit) { + return (codeunit & 0xF0) == 0xE0; // upper four bits should be 1110 +} + +static inline bool Utf8Is4ByteStart(const uint8_t codeunit) { + return (codeunit & 0xF8) == 0xF0; // upper five bits should be 11110 +} + +static inline uint8_t* UTF8Encode(uint8_t* str, uint32_t codepoint) { + if (codepoint < 0x80) { + *str++ = codepoint; + } else if (codepoint < 0x800) { + *str++ = 0xC0 + (codepoint >> 6); + *str++ = 0x80 + (codepoint & 0x3F); + } else if (codepoint < 0x10000) { + *str++ = 0xE0 + (codepoint >> 12); + *str++ = 0x80 + ((codepoint >> 6) & 0x3F); + *str++ = 0x80 + (codepoint & 0x3F); + } else { + // Assume proper codepoints are always passed + assert(codepoint < kMaxUnicodeCodepoint); + *str++ = 0xF0 + (codepoint >> 18); + *str++ = 0x80 + ((codepoint >> 12) & 0x3F); + *str++ = 0x80 + ((codepoint >> 6) & 0x3F); + *str++ = 0x80 + (codepoint & 0x3F); + } + return str; +} + +static inline bool UTF8Decode(const uint8_t** data, uint32_t* codepoint) { + const uint8_t* str = *data; + if (*str < 0x80) { // ascii + *codepoint = *str++; + } else if (ARROW_PREDICT_FALSE(*str < 0xC0)) { // invalid non-ascii char + return false; + } else if (*str < 0xE0) { + uint8_t code_unit_1 = (*str++) & 0x1F; // take last 5 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits + *codepoint = (code_unit_1 << 6) + code_unit_2; + } else if (*str < 0xF0) { + uint8_t code_unit_1 = (*str++) & 0x0F; // take last 4 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_3 = (*str++) & 0x3F; // take last 6 bits + *codepoint = (code_unit_1 << 12) + (code_unit_2 << 6) + code_unit_3; + } else if (*str < 0xF8) { + uint8_t code_unit_1 = (*str++) & 0x07; // take last 3 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_3 = (*str++) & 0x3F; // take last 6 bits + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_4 = (*str++) & 0x3F; // take last 6 bits + *codepoint = + (code_unit_1 << 18) + (code_unit_2 << 12) + (code_unit_3 << 6) + code_unit_4; + } else { // invalid non-ascii char + return false; + } + *data = str; + return true; +} + +static inline bool UTF8DecodeReverse(const uint8_t** data, uint32_t* codepoint) { + const uint8_t* str = *data; + if (*str < 0x80) { // ascii + *codepoint = *str--; + } else { + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_N = (*str--) & 0x3F; // take last 6 bits + if (Utf8Is2ByteStart(*str)) { + uint8_t code_unit_1 = (*str--) & 0x1F; // take last 5 bits + *codepoint = (code_unit_1 << 6) + code_unit_N; + } else { + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_Nmin1 = (*str--) & 0x3F; // take last 6 bits + if (Utf8Is3ByteStart(*str)) { + uint8_t code_unit_1 = (*str--) & 0x0F; // take last 4 bits + *codepoint = (code_unit_1 << 12) + (code_unit_Nmin1 << 6) + code_unit_N; + } else { + if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) { + return false; + } + uint8_t code_unit_Nmin2 = (*str--) & 0x3F; // take last 6 bits + if (ARROW_PREDICT_TRUE(Utf8Is4ByteStart(*str))) { + uint8_t code_unit_1 = (*str--) & 0x07; // take last 3 bits + *codepoint = (code_unit_1 << 18) + (code_unit_Nmin2 << 12) + + (code_unit_Nmin1 << 6) + code_unit_N; + } else { + return false; + } + } + } + } + *data = str; + return true; +} + +template <class UnaryOperation> +static inline bool UTF8Transform(const uint8_t* first, const uint8_t* last, + uint8_t** destination, UnaryOperation&& unary_op) { + const uint8_t* i = first; + uint8_t* out = *destination; + while (i < last) { + uint32_t codepoint = 0; + if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { + return false; + } + out = UTF8Encode(out, unary_op(codepoint)); + } + *destination = out; + return true; +} + +template <class Predicate> +static inline bool UTF8FindIf(const uint8_t* first, const uint8_t* last, + Predicate&& predicate, const uint8_t** position) { + const uint8_t* i = first; + while (i < last) { + uint32_t codepoint = 0; + const uint8_t* current = i; + if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { + return false; + } + if (predicate(codepoint)) { + *position = current; + return true; + } + } + *position = last; + return true; +} + +// Same semantics as std::find_if using reverse iterators with the return value +// having the same semantics as std::reverse_iterator<..>.base() +// A reverse iterator physically points to the next address, e.g.: +// &*reverse_iterator(i) == &*(i + 1) +template <class Predicate> +static inline bool UTF8FindIfReverse(const uint8_t* first, const uint8_t* last, + Predicate&& predicate, const uint8_t** position) { + // converts to a normal point + const uint8_t* i = last - 1; + while (i >= first) { + uint32_t codepoint = 0; + const uint8_t* current = i; + if (ARROW_PREDICT_FALSE(!UTF8DecodeReverse(&i, &codepoint))) { + return false; + } + if (predicate(codepoint)) { + // converts normal pointer to 'reverse iterator semantics'. + *position = current + 1; + return true; + } + } + // similar to how an end pointer point to 1 beyond the last, reverse iterators point + // to the 'first' pointer to indicate out of range. + *position = first; + return true; +} + +static inline bool UTF8AdvanceCodepoints(const uint8_t* first, const uint8_t* last, + const uint8_t** destination, int64_t n) { + return UTF8FindIf( + first, last, + [&](uint32_t codepoint) { + bool done = n == 0; + n--; + return done; + }, + destination); +} + +static inline bool UTF8AdvanceCodepointsReverse(const uint8_t* first, const uint8_t* last, + const uint8_t** destination, int64_t n) { + return UTF8FindIfReverse( + first, last, + [&](uint32_t codepoint) { + bool done = n == 0; + n--; + return done; + }, + destination); +} + +template <class UnaryFunction> +static inline bool UTF8ForEach(const uint8_t* first, const uint8_t* last, + UnaryFunction&& f) { + const uint8_t* i = first; + while (i < last) { + uint32_t codepoint = 0; + if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { + return false; + } + f(codepoint); + } + return true; +} + +template <class UnaryFunction> +static inline bool UTF8ForEach(const std::string& s, UnaryFunction&& f) { + return UTF8ForEach(reinterpret_cast<const uint8_t*>(s.data()), + reinterpret_cast<const uint8_t*>(s.data() + s.length()), + std::forward<UnaryFunction>(f)); +} + +template <class UnaryPredicate> +static inline bool UTF8AllOf(const uint8_t* first, const uint8_t* last, bool* result, + UnaryPredicate&& predicate) { + const uint8_t* i = first; + while (i < last) { + uint32_t codepoint = 0; + if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) { + return false; + } + + if (!predicate(codepoint)) { + *result = false; + return true; + } + } + *result = true; + return true; +} + +/// Count the number of codepoints in the given string (assuming it is valid UTF8). +static inline int64_t UTF8Length(const uint8_t* first, const uint8_t* last) { + int64_t length = 0; + while (first != last) { + length += ((*first++ & 0xc0) != 0x80); + } + return length; +} + +} // namespace util +} // namespace arrow |